Skip to main content

datafusion_functions_nested/
array_any_match.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`datafusion_expr::HigherOrderUDF`] definitions for array_any_match function.
19
20use arrow::{
21    array::{Array, AsArray, BooleanArray, BooleanBuilder, new_null_array},
22    buffer::NullBuffer,
23    compute::take_arrays,
24    datatypes::{ArrowNativeType, DataType, Field, FieldRef},
25};
26use datafusion_common::{
27    Result, exec_datafusion_err, exec_err, plan_err,
28    utils::{
29        adjust_offsets_for_slice, list_values, list_values_row_number, take_function_args,
30    },
31};
32use datafusion_expr::{
33    ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs,
34    HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda,
35    Volatility,
36};
37use datafusion_macros::user_doc;
38use std::{fmt::Debug, sync::Arc};
39
// Wires `ArrayAnyMatch` into the crate through the shared helper macro.
// The tests in this file call the generated
// `array_any_match_higher_order_function()` accessor; `array_any_match` is
// presumably the name of the generated expression builder taking `array` and
// `lambda` arguments — confirm against the macro definition.
make_higher_order_function_expr_and_func!(
    ArrayAnyMatch,
    array_any_match,
    array lambda,
    "returns true if any element in the array satisfies the predicate",
    array_any_match_higher_order_function
);
47
// Higher-order function that reports, per input row, whether any element of a
// list satisfies a lambda predicate. The user-facing documentation is declared
// through the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns whether any elements of an array match the given predicate. Returns true if one or more elements match, false if none match (including empty arrays), and null if the predicate returns null for some elements and false for all others.",
    syntax_example = "any_match(array, predicate)",
    sql_example = r#"```sql
> select any_match([1, 2, 3], x -> x > 2);
+----------------------------------+
| any_match([1, 2, 3], x -> x > 2) |
+----------------------------------+
| true                             |
+----------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "predicate",
        description = "Lambda predicate that returns a boolean"
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayAnyMatch {
    // Accepted argument shape; `new` sets it to one value followed by one lambda.
    signature: HigherOrderSignature,
    // Alternative names for the function; `new` registers `any_match` and
    // `list_any_match`.
    aliases: Vec<String>,
}
74
75impl Default for ArrayAnyMatch {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl ArrayAnyMatch {
82    pub fn new() -> Self {
83        Self {
84            signature: HigherOrderSignature::exact(
85                vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
86                Volatility::Immutable,
87            ),
88            aliases: vec![String::from("any_match"), String::from("list_any_match")],
89        }
90    }
91}
92
93// Returns Some(true) if any element in [start, end) is true,
94// None if no element is true but some are null,
95// Some(false) if all are false or range is empty.
96fn any_match_for_range(
97    predicate: &BooleanArray,
98    start: usize,
99    end: usize,
100) -> Option<bool> {
101    let any_true = (start..end).any(|j| predicate.is_valid(j) && predicate.value(j));
102    if any_true {
103        return Some(true);
104    }
105    let any_null = (start..end).any(|j| predicate.is_null(j));
106    if any_null { None } else { Some(false) }
107}
108
109impl HigherOrderUDFImpl for ArrayAnyMatch {
110    fn name(&self) -> &str {
111        "array_any_match"
112    }
113
114    fn aliases(&self) -> &[String] {
115        &self.aliases
116    }
117
118    fn signature(&self) -> &HigherOrderSignature {
119        &self.signature
120    }
121
122    fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
123        let [list] = arg_types else {
124            return plan_err!(
125                "{} function requires 1 value argument, got {}",
126                self.name(),
127                arg_types.len()
128            );
129        };
130
131        let coerced = match list {
132            DataType::List(_) | DataType::LargeList(_) => list.clone(),
133            DataType::ListView(field) | DataType::FixedSizeList(field, _) => {
134                DataType::List(Arc::clone(field))
135            }
136            DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)),
137            _ => {
138                return plan_err!(
139                    "{} expected a list as first argument, got {}",
140                    self.name(),
141                    list
142                );
143            }
144        };
145
146        Ok(vec![coerced])
147    }
148
149    fn lambda_parameters(
150        &self,
151        _step: usize,
152        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
153    ) -> Result<LambdaParametersProgress> {
154        let [list, _] = take_function_args(self.name(), fields)?;
155        let ValueOrLambda::Value(list) = list else {
156            return plan_err!("{} expects a value as first argument", self.name());
157        };
158
159        let field = match list.data_type() {
160            DataType::List(field) => field,
161            DataType::LargeList(field) => field,
162            other => return plan_err!("expected list, got {other}"),
163        };
164
165        Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone(
166            field,
167        )]]))
168    }
169
170    fn return_field_from_args(
171        &self,
172        args: HigherOrderReturnFieldArgs,
173    ) -> Result<Arc<Field>> {
174        let [ValueOrLambda::Value(list), _] =
175            take_function_args(self.name(), args.arg_fields)?
176        else {
177            return plan_err!("{} expects a value as first argument", self.name());
178        };
179        let nullable = list.is_nullable();
180        Ok(Arc::new(Field::new("", DataType::Boolean, nullable)))
181    }
182
183    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
184        let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
185            take_function_args(self.name(), &args.args)?
186        else {
187            return exec_err!("{} expects a value followed by a lambda", self.name());
188        };
189
190        let list_array = list.to_array(args.number_rows)?;
191
192        // fast path: fully null input — also required for FixedSizeList which can't be
193        // handled by clear_null_values when fully null
194        if list_array.null_count() == list_array.len() {
195            return Ok(ColumnarValue::Array(new_null_array(
196                args.return_type(),
197                list_array.len(),
198            )));
199        }
200
201        let list_values = list_values(&list_array)?;
202
203        let values_param = || Ok(Arc::clone(&list_values));
204
205        let predicate_results = lambda
206            .evaluate(&[&values_param], |arrays| {
207                let indices = list_values_row_number(&list_array)?;
208                Ok(take_arrays(arrays, &indices, None)?)
209            })?
210            .into_array(list_values.len())?;
211
212        let predicate_bool = predicate_results
213            .as_any()
214            .downcast_ref::<BooleanArray>()
215            .ok_or_else(|| {
216                exec_datafusion_err!(
217                    "{} predicate must return boolean array",
218                    self.name()
219                )
220            })?;
221
222        let mut values = BooleanBuilder::with_capacity(list_array.len());
223
224        // Maps predicate results (flat over all elements) back to one Boolean per row.
225        // Uses adjusted offsets so sliced lists index correctly into the predicate array.
226        macro_rules! process_list {
227            ($list_typed:expr) => {{
228                let offsets = adjust_offsets_for_slice($list_typed);
229                for i in 0..$list_typed.len() {
230                    let start = offsets[i].as_usize();
231                    let end = offsets[i + 1].as_usize();
232                    // any_match_for_range returns None when nulls poison the result;
233                    // null rows produce an empty range and return Some(false), but their
234                    // null bit is preserved by attaching the original null bitmap below.
235                    values.append_option(any_match_for_range(predicate_bool, start, end));
236                }
237            }};
238        }
239
240        match list_array.data_type() {
241            DataType::List(_) => {
242                process_list!(list_array.as_list::<i32>());
243            }
244            DataType::LargeList(_) => {
245                process_list!(list_array.as_list::<i64>());
246            }
247            other => return exec_err!("expected list, got {other}"),
248        }
249
250        let (boolean_buffer, predicate_nulls) = values.finish().into_parts();
251        // Merge: a row is null if the input list row was null or the predicate returned null.
252        let nulls = NullBuffer::union(list_array.nulls(), predicate_nulls.as_ref());
253        Ok(ColumnarValue::Array(Arc::new(BooleanArray::new(
254            boolean_buffer,
255            nulls,
256        ))))
257    }
258
259    fn documentation(&self) -> Option<&Documentation> {
260        self.doc()
261    }
262}
263
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{
        array::{ArrayRef, BooleanArray, Int32Array, ListArray, RecordBatch},
        buffer::{NullBuffer, OffsetBuffer},
        datatypes::{DataType, Field},
    };
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr::{
        Expr, col,
        execution_props::ExecutionProps,
        expr::{HigherOrderFunction, LambdaVariable},
        lambda, lit,
    };
    use datafusion_physical_expr::create_physical_expr;

    use crate::array_any_match::array_any_match_higher_order_function;

    // Plans and evaluates `array_any_match(list, x -> <predicate(x)>)` against a
    // single-column batch holding `list`. `predicate` receives the lambda
    // variable `x` (declared as a nullable Int32) and returns the lambda body.
    fn evaluate_any_match(
        list: impl arrow::array::Array + 'static,
        predicate: impl FnOnce(Expr) -> Expr,
    ) -> Result<ArrayRef> {
        let num_rows = list.len();
        let schema = DFSchema::from_unqualified_fields(
            vec![Field::new(
                "list",
                list.data_type().clone(),
                list.is_nullable(),
            )]
            .into(),
            HashMap::new(),
        )?;

        let x = Expr::LambdaVariable(LambdaVariable::new(
            "x".to_string(),
            Some(Arc::new(Field::new("x", DataType::Int32, true))),
        ));
        let expr = Expr::HigherOrderFunction(HigherOrderFunction::new(
            array_any_match_higher_order_function(),
            vec![col("list"), lambda(["x"], predicate(x))],
        ));

        let physical_expr = create_physical_expr(&expr, &schema, &ExecutionProps::new())?;
        let batch =
            RecordBatch::try_new(Arc::clone(schema.inner()), vec![Arc::new(list)])?;

        physical_expr.evaluate(&batch)?.into_array(num_rows)
    }

    // Evaluates `array_any_match(list, x -> x > 2)`.
    fn run_any_match(
        list: impl arrow::array::Array + Clone + 'static,
    ) -> Result<ArrayRef> {
        evaluate_any_match(list, |x| x.gt(lit(2i32)))
    }

    // Evaluates `array_any_match(list, x -> (100 / x) > 5)`. The division is
    // expected to fail with a divide-by-zero if an element equal to 0 is ever
    // fed to the predicate.
    fn run_any_match_div(
        list: impl arrow::array::Array + Clone + 'static,
    ) -> Result<ArrayRef> {
        evaluate_any_match(list, |x| (lit(100i32) / x).gt(lit(5i32)))
    }

    fn make_list(values: Vec<i32>, offsets: OffsetBuffer<i32>) -> ListArray {
        make_list_with_nulls(values, offsets, None)
    }

    fn make_list_with_nulls(
        values: Vec<i32>,
        offsets: OffsetBuffer<i32>,
        nulls: Option<NullBuffer>,
    ) -> ListArray {
        ListArray::new(
            Arc::new(Field::new_list_field(DataType::Int32, true)),
            offsets,
            Arc::new(Int32Array::from(values)),
            nulls,
        )
    }

    #[test]
    fn test_any_match_some_true() -> Result<()> {
        let list = make_list(vec![1, 2, 3], OffsetBuffer::from_lengths(vec![3]));
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![Some(true)])
        );
        Ok(())
    }

    #[test]
    fn test_any_match_none_true() -> Result<()> {
        let list = make_list(vec![1, 2], OffsetBuffer::from_lengths(vec![2]));
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![Some(false)])
        );
        Ok(())
    }

    #[test]
    fn test_any_match_empty_array() -> Result<()> {
        let list = make_list(vec![], OffsetBuffer::from_lengths(vec![0]));
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![Some(false)])
        );
        Ok(())
    }

    #[test]
    fn test_any_match_multiple_rows() -> Result<()> {
        let list = make_list(vec![1, 2, 3, 1, 2], OffsetBuffer::from_lengths(vec![3, 2]));
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![Some(true), Some(false)])
        );
        Ok(())
    }

    // A null predicate result only decides the row when no element is true:
    // [1, NULL] -> (false, NULL) -> NULL, while [NULL, 3] -> (NULL, true) -> true.
    #[test]
    fn test_any_match_null_predicate_results() -> Result<()> {
        let list = ListArray::new(
            Arc::new(Field::new_list_field(DataType::Int32, true)),
            OffsetBuffer::from_lengths(vec![2, 2]),
            Arc::new(Int32Array::from(vec![Some(1), None, None, Some(3)])),
            None,
        );
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![None, Some(true)])
        );
        Ok(())
    }

    // Predicate must not be evaluated on elements belonging to null rows.
    // The 10 in the null row would satisfy x > 2, but the row result must be None.
    #[test]
    fn test_any_match_should_not_evaluate_predicate_on_values_underlying_null()
    -> Result<()> {
        let list = make_list_with_nulls(
            vec![1, 2, 10, 1, 2],
            OffsetBuffer::from_lengths(vec![3, 2]),
            Some(NullBuffer::from(vec![false, true])),
        );
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![None, Some(false)])
        );
        Ok(())
    }

    // Predicate must not be evaluated on elements before the slice offset.
    // The 10 before the slice would satisfy x > 2, but it is unreachable.
    #[test]
    fn test_any_match_on_sliced_list_should_not_evaluate_on_unreachable_values()
    -> Result<()> {
        let list = make_list(
            vec![10, 1, 2, 1, 2],
            OffsetBuffer::from_lengths(vec![1, 2, 2]),
        )
        .slice(1, 2);
        let result = run_any_match(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![Some(false), Some(false)])
        );
        Ok(())
    }

    // 0 in the null row would cause divide by zero if the predicate is evaluated on it.
    #[test]
    fn test_any_match_does_not_evaluate_predicate_on_null_row_values() -> Result<()> {
        let list = make_list_with_nulls(
            vec![1, 2, 0, 4, 5],
            OffsetBuffer::from_lengths(vec![3, 2]),
            Some(NullBuffer::from(vec![false, true])),
        );
        let result = run_any_match_div(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![None, Some(true)])
        );
        Ok(())
    }

    // 0 before the slice offset would cause divide by zero if evaluated.
    #[test]
    fn test_any_match_does_not_evaluate_predicate_on_unreachable_values() -> Result<()> {
        let list = make_list(
            vec![0, 4, 5, 50, 100],
            OffsetBuffer::from_lengths(vec![1, 2, 2]),
        )
        .slice(1, 2);
        let result = run_any_match_div(list)?;
        assert_eq!(
            result.as_any().downcast_ref::<BooleanArray>().unwrap(),
            &BooleanArray::from(vec![Some(true), Some(false)])
        );
        Ok(())
    }
}