Skip to main content

datafusion_functions_nested/
array_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`datafusion_expr::HigherOrderUDF`] definition for the `array_filter` function.
19
20use arrow::{
21    array::{
22        Array, ArrayRef, AsArray, BooleanArray, LargeListArray, ListArray,
23        OffsetBufferBuilder, OffsetSizeTrait, new_empty_array,
24    },
25    buffer::{OffsetBuffer, ScalarBuffer},
26    compute::{filter as arrow_filter, take_arrays},
27    datatypes::{DataType, Field, FieldRef},
28};
29use datafusion_common::{
30    Result, ScalarValue, exec_err,
31    utils::{adjust_offsets_for_slice, list_values_row_number},
32};
33use datafusion_expr::{
34    ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs,
35    HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda,
36    Volatility,
37};
38use datafusion_macros::user_doc;
39use std::sync::Arc;
40
41use crate::lambda_utils::{
42    ListValuesResult, coerce_single_list_arg, extract_list_values,
43    single_list_lambda_parameters, value_lambda_pair,
44};
45
// Crate-local macro (defined outside this file) that generates the public entry points
// for `ArrayFilter`. The tests at the bottom of this file call the generated
// `array_filter_higher_order_function()` to obtain the UDF.
// NOTE(review): `array_filter` is presumably the expression-builder fn, `array lambda`
// its argument names and the string its doc text — confirm against the macro definition.
make_higher_order_function_expr_and_func!(
    ArrayFilter,
    array_filter,
    array lambda,
    "filters the values of an array using a boolean lambda",
    array_filter_higher_order_function
);
53
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "filters the values of an array using a boolean lambda",
    syntax_example = "array_filter(array, x -> x > 2)",
    sql_example = r#"```sql
> select array_filter([1, 2, 3, 4, 5], x -> x > 2);
+--------------------------------------------+
| array_filter([1, 2, 3, 4, 5], x -> x > 2) |
+--------------------------------------------+
| [3, 4, 5]                                  |
+--------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "lambda",
        description = "Lambda that returns a boolean. Elements for which the lambda returns true are kept."
    )
)]
/// Higher-order UDF behind `array_filter` (alias `list_filter`): keeps the list
/// elements for which the boolean lambda returns `true`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayFilter {
    /// Call signature: exactly one value argument followed by one lambda, immutable.
    signature: HigherOrderSignature,
    /// Alternative names returned by `aliases()` (currently `list_filter`).
    aliases: Vec<String>,
}
80
81impl Default for ArrayFilter {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87impl ArrayFilter {
88    pub fn new() -> Self {
89        Self {
90            signature: HigherOrderSignature::exact(
91                vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
92                Volatility::Immutable,
93            ),
94            aliases: vec![String::from("list_filter")],
95        }
96    }
97}
98
99impl HigherOrderUDFImpl for ArrayFilter {
100    fn name(&self) -> &str {
101        "array_filter"
102    }
103
104    fn aliases(&self) -> &[String] {
105        &self.aliases
106    }
107
108    fn signature(&self) -> &HigherOrderSignature {
109        &self.signature
110    }
111
112    fn lambda_parameters(
113        &self,
114        _step: usize,
115        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
116    ) -> Result<LambdaParametersProgress> {
117        single_list_lambda_parameters(self.name(), fields)
118    }
119
120    fn return_field_from_args(
121        &self,
122        args: HigherOrderReturnFieldArgs,
123    ) -> Result<Arc<Field>> {
124        let (list, _lambda) = value_lambda_pair(self.name(), args.arg_fields)?;
125        Ok(Arc::new(Field::new(
126            "",
127            list.data_type().clone(),
128            list.is_nullable(),
129        )))
130    }
131
132    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
133        let (list, lambda) = value_lambda_pair(self.name(), &args.args)?;
134        let list_array = list.to_array(args.number_rows)?;
135
136        let list_values = match extract_list_values(&list_array, args.return_type())? {
137            ListValuesResult::EarlyReturn(v) => return Ok(v),
138            ListValuesResult::Values(v) => v,
139        };
140
141        let field = match args.return_field.data_type() {
142            DataType::List(field) | DataType::LargeList(field) => Arc::clone(field),
143            _ => {
144                return exec_err!(
145                    "{} expected return_field to be a list, got {}",
146                    self.name(),
147                    args.return_field
148                );
149            }
150        };
151
152        let values_param = || Ok(Arc::clone(&list_values));
153        let predicate_output = lambda.evaluate(&[&values_param], |arrays| {
154            let indices = list_values_row_number(&list_array)?;
155            Ok(take_arrays(arrays, &indices, None)?)
156        })?;
157
158        // Scalar predicate short-circuit: x -> true or x -> false/null
159        if let ColumnarValue::Scalar(ScalarValue::Boolean(b)) = &predicate_output {
160            return match b {
161                Some(true) => Ok(ColumnarValue::Array(list_array)),
162                _ => Ok(ColumnarValue::Array(empty_filtered_list(
163                    &list_array,
164                    field,
165                )?)),
166            };
167        }
168
169        let predicate = predicate_output.into_array(list_values.len())?;
170        let Some(predicate) = predicate.as_any().downcast_ref::<BooleanArray>() else {
171            return exec_err!(
172                "{} lambda must return boolean, got {}",
173                self.name(),
174                predicate.data_type()
175            );
176        };
177
178        // ListView and LargeListView are coerced to List/LargeList by coerce_value_types.
179        let filtered_list = match list_array.data_type() {
180            DataType::List(_) => {
181                let list = list_array.as_list::<i32>();
182                let adjusted_offsets = adjust_offsets_for_slice(list);
183                let (filtered_values, new_offsets) =
184                    filter_list_values(&list_values, predicate, &adjusted_offsets)?;
185                Arc::new(ListArray::new(
186                    field,
187                    new_offsets,
188                    filtered_values,
189                    list.nulls().cloned(),
190                )) as ArrayRef
191            }
192            DataType::LargeList(_) => {
193                let large_list = list_array.as_list::<i64>();
194                let adjusted_offsets = adjust_offsets_for_slice(large_list);
195                let (filtered_values, new_offsets) =
196                    filter_list_values(&list_values, predicate, &adjusted_offsets)?;
197                Arc::new(LargeListArray::new(
198                    field,
199                    new_offsets,
200                    filtered_values,
201                    large_list.nulls().cloned(),
202                ))
203            }
204            other => exec_err!("expected list, got {other}")?,
205        };
206
207        Ok(ColumnarValue::Array(filtered_list))
208    }
209
210    fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
211        coerce_single_list_arg(self.name(), arg_types)
212    }
213
214    fn documentation(&self) -> Option<&Documentation> {
215        self.doc()
216    }
217}
218
219/// Returns a list array with every non-null sublist emptied, preserving the null buffer.
220/// Used for the `x -> false` / `x -> null` scalar predicate short-circuit.
221fn empty_filtered_list(list_array: &ArrayRef, field: FieldRef) -> Result<ArrayRef> {
222    let n = list_array.len();
223    let empty_values = new_empty_array(field.data_type());
224    Ok(match list_array.data_type() {
225        DataType::List(_) => {
226            let list = list_array.as_list::<i32>();
227            Arc::new(ListArray::new(
228                field,
229                OffsetBuffer::new(ScalarBuffer::from(vec![0i32; n + 1])),
230                empty_values,
231                list.nulls().cloned(),
232            ))
233        }
234        DataType::LargeList(_) => {
235            let list = list_array.as_list::<i64>();
236            Arc::new(LargeListArray::new(
237                field,
238                OffsetBuffer::new(ScalarBuffer::from(vec![0i64; n + 1])),
239                empty_values,
240                list.nulls().cloned(),
241            ))
242        }
243        other => return exec_err!("expected list, got {other}"),
244    })
245}
246
247/// Filters flat list values using a boolean predicate, returning filtered values and
248/// recomputed per-sublist offsets. Null predicate values are treated as false.
249fn filter_list_values<O: OffsetSizeTrait>(
250    values: &ArrayRef,
251    predicate: &BooleanArray,
252    offsets: &OffsetBuffer<O>,
253) -> Result<(ArrayRef, OffsetBuffer<O>)> {
254    let num_sublists = offsets.len().saturating_sub(1);
255    let mut builder = OffsetBufferBuilder::<O>::new(num_sublists);
256
257    let has_nulls = predicate.null_count() > 0;
258    for i in 0..num_sublists {
259        let start = offsets[i].as_usize();
260        let end = offsets[i + 1].as_usize();
261        let count = if has_nulls {
262            (start..end)
263                .filter(|&j| predicate.is_valid(j) && predicate.value(j))
264                .count()
265        } else {
266            predicate
267                .values()
268                .slice(start, end - start)
269                .count_set_bits()
270        };
271        builder.push_length(count);
272    }
273
274    let new_offsets = builder.finish();
275
276    if new_offsets.last() == offsets.last() {
277        return Ok((Arc::clone(values), offsets.clone()));
278    }
279
280    // arrow_filter treats null predicate values as false
281    let filtered_values = arrow_filter(values.as_ref(), predicate)?;
282    Ok((filtered_values, new_offsets))
283}
284
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, AsArray},
        buffer::{NullBuffer, OffsetBuffer},
    };

    use crate::array_filter::array_filter_higher_order_function;
    use crate::lambda_utils::test_utils::{create_i32_list, eval_hof_on_i32_list, v};
    use datafusion_common::ScalarValue;
    use datafusion_expr::lit;

    /// Runs `array_filter(list, x -> x > 2)` on an i32 list.
    fn keep_greater_than_two(
        list: impl Array + Clone + 'static,
    ) -> datafusion_common::Result<arrow::array::ArrayRef> {
        eval_hof_on_i32_list(
            array_filter_higher_order_function(),
            list,
            v().gt(lit(2i32)),
        )
    }

    #[test]
    fn filter_basic() {
        let list = create_i32_list(
            vec![1, 2, 3, 4, 5],
            OffsetBuffer::<i32>::from_lengths(vec![5]),
            None,
        );

        let res = keep_greater_than_two(list).unwrap();
        let actual = res.as_list::<i32>();

        let expected = create_i32_list(
            vec![3, 4, 5],
            OffsetBuffer::<i32>::from_lengths(vec![3]),
            None,
        );

        assert_eq!(actual, &expected);
    }

    #[test]
    fn filter_multiple_sublists() {
        let list = create_i32_list(
            vec![1, 5, 2, 4, 3],
            OffsetBuffer::<i32>::from_lengths(vec![2, 3]),
            None,
        );

        let res = keep_greater_than_two(list).unwrap();
        let actual = res.as_list::<i32>();

        // [1,5] -> [5], [2,4,3] -> [4,3]
        let expected = create_i32_list(
            vec![5, 4, 3],
            OffsetBuffer::<i32>::from_lengths(vec![1, 2]),
            None,
        );

        assert_eq!(actual, &expected);
    }

    #[test]
    fn filter_on_sliced_list_should_not_evaluate_on_unreachable_values() {
        // First sublist [0] is sliced away; sliced array covers sublists [1..3]
        let list = create_i32_list(
            vec![
                0, // unreachable after slice — if evaluated, it would appear in output
                1, 5, 2, 4, 3, 7,
            ],
            OffsetBuffer::<i32>::from_lengths(vec![1, 3, 3]),
            None,
        )
        .slice(1, 2);

        let res = keep_greater_than_two(list).unwrap();
        let actual = res.as_list::<i32>();

        // [1,5,2] -> [5], [4,3,7] -> [4,3,7]
        let expected = create_i32_list(
            vec![5, 4, 3, 7],
            OffsetBuffer::<i32>::from_lengths(vec![1, 3]),
            None,
        );

        assert_eq!(actual, &expected);
    }

    #[test]
    fn filter_should_not_be_evaluated_on_values_underlying_null() {
        // The null sublist (index 1) contains values that would pass the predicate
        // if evaluated. We verify they do NOT appear in the output.
        let list = create_i32_list(
            vec![1, 5, 99, 100, 3, 7],
            OffsetBuffer::<i32>::from_lengths(vec![2, 2, 2]),
            Some(NullBuffer::from(vec![true, false, true])),
        );

        let res = keep_greater_than_two(list).unwrap();
        let actual = res.as_list::<i32>();

        // sublist 0: [1,5] -> [5]
        // sublist 1: null  -> null (empty range, null bit)
        // sublist 2: [3,7] -> [3,7]
        let expected = create_i32_list(
            vec![5, 3, 7],
            OffsetBuffer::<i32>::from_lengths(vec![1, 0, 2]),
            Some(NullBuffer::from(vec![true, false, true])),
        );

        assert_eq!(actual.data_type(), expected.data_type());
        assert_eq!(actual, &expected);
    }

    #[test]
    fn filter_all_filtered_out() {
        let list =
            create_i32_list(vec![1, 2], OffsetBuffer::<i32>::from_lengths(vec![2]), None);

        let res = keep_greater_than_two(list).unwrap();
        let actual = res.as_list::<i32>();

        let expected = create_i32_list(
            vec![0i32; 0],
            OffsetBuffer::<i32>::from_lengths(vec![0]),
            None,
        );

        assert_eq!(actual, &expected);
    }

    #[test]
    fn filter_nothing_filtered_reuses_values() {
        let list = create_i32_list(
            vec![3, 4, 5],
            OffsetBuffer::<i32>::from_lengths(vec![3]),
            None,
        );
        // all elements > 2, so nothing is filtered — values buffer should be reused
        let res = keep_greater_than_two(list.clone()).unwrap();
        assert_eq!(res.as_list::<i32>(), &list);
    }

    #[test]
    fn scalar_true_predicate_returns_original_list() {
        let list = create_i32_list(
            vec![1, 2, 3],
            OffsetBuffer::<i32>::from_lengths(vec![3]),
            None,
        );
        // x -> true: every element kept, should return list unchanged
        let res = eval_hof_on_i32_list(
            array_filter_higher_order_function(),
            list.clone(),
            lit(true),
        )
        .unwrap();
        assert_eq!(res.as_list::<i32>(), &list);
    }

    #[test]
    fn scalar_false_predicate_returns_empty_sublists() {
        let list = create_i32_list(
            vec![1, 2, 3, 4],
            OffsetBuffer::<i32>::from_lengths(vec![2, 2]),
            None,
        );
        // x -> false: every sublist emptied
        let res =
            eval_hof_on_i32_list(array_filter_higher_order_function(), list, lit(false))
                .unwrap();
        let actual = res.as_list::<i32>();
        let expected = create_i32_list(
            vec![0i32; 0],
            OffsetBuffer::<i32>::from_lengths(vec![0, 0]),
            None,
        );
        assert_eq!(actual, &expected);
    }

    #[test]
    fn scalar_false_predicate_preserves_null_rows() {
        // x -> false on a list with a null row: the valid row becomes empty and the
        // null row must stay null.
        let list = create_i32_list(
            vec![1, 2, 3, 4],
            OffsetBuffer::<i32>::from_lengths(vec![2, 2]),
            Some(NullBuffer::from(vec![true, false])),
        );
        let res =
            eval_hof_on_i32_list(array_filter_higher_order_function(), list, lit(false))
                .unwrap();
        let actual = res.as_list::<i32>();
        let expected = create_i32_list(
            vec![0i32; 0],
            OffsetBuffer::<i32>::from_lengths(vec![0, 0]),
            Some(NullBuffer::from(vec![true, false])),
        );
        assert_eq!(actual, &expected);
    }

    #[test]
    fn scalar_null_predicate_is_treated_as_false() {
        let list = create_i32_list(
            vec![1, 2, 3, 4],
            OffsetBuffer::<i32>::from_lengths(vec![2, 2]),
            None,
        );
        // x -> NULL::boolean: a null predicate drops the element, same as false
        let res = eval_hof_on_i32_list(
            array_filter_higher_order_function(),
            list,
            lit(ScalarValue::Boolean(None)),
        )
        .unwrap();
        let actual = res.as_list::<i32>();
        let expected = create_i32_list(
            vec![0i32; 0],
            OffsetBuffer::<i32>::from_lengths(vec![0, 0]),
            None,
        );
        assert_eq!(actual, &expected);
    }
}