1use arrow::{
21 array::{
22 Array, ArrayRef, AsArray, BooleanArray, LargeListArray, ListArray,
23 OffsetBufferBuilder, OffsetSizeTrait, new_empty_array,
24 },
25 buffer::{OffsetBuffer, ScalarBuffer},
26 compute::{filter as arrow_filter, take_arrays},
27 datatypes::{DataType, Field, FieldRef},
28};
29use datafusion_common::{
30 Result, ScalarValue, exec_err,
31 utils::{adjust_offsets_for_slice, list_values_row_number},
32};
33use datafusion_expr::{
34 ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs,
35 HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda,
36 Volatility,
37};
38use datafusion_macros::user_doc;
39use std::sync::Arc;
40
41use crate::lambda_utils::{
42 ListValuesResult, coerce_single_list_arg, extract_list_values,
43 single_list_lambda_parameters, value_lambda_pair,
44};
45
// Registers `ArrayFilter` through the crate's
// `make_higher_order_function_expr_and_func!` macro.
// NOTE(review): presumably this expands to an `array_filter` expression helper
// and the `array_filter_higher_order_function()` accessor (the latter is what the
// tests in this file call); the exact generated items are defined by the macro —
// confirm there.
make_higher_order_function_expr_and_func!(
    ArrayFilter,
    array_filter,
    array lambda,
    "filters the values of an array using a boolean lambda",
    array_filter_higher_order_function
);
53
/// The `array_filter` higher-order function (SQL alias `list_filter`).
///
/// For every row, keeps only the elements of the input list for which the
/// boolean lambda returns `true`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "filters the values of an array using a boolean lambda",
    syntax_example = "array_filter(array, x -> x > 2)",
    sql_example = r#"```sql
> select array_filter([1, 2, 3, 4, 5], x -> x > 2);
+--------------------------------------------+
| array_filter([1, 2, 3, 4, 5], x -> x > 2) |
+--------------------------------------------+
| [3, 4, 5] |
+--------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "lambda",
        description = "Lambda that returns a boolean. Elements for which the lambda returns true are kept."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayFilter {
    /// Exact signature: one value argument (the list) followed by one lambda.
    signature: HigherOrderSignature,
    /// Alternative names this function is registered under (`list_filter`).
    aliases: Vec<String>,
}
80
81impl Default for ArrayFilter {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl ArrayFilter {
88 pub fn new() -> Self {
89 Self {
90 signature: HigherOrderSignature::exact(
91 vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
92 Volatility::Immutable,
93 ),
94 aliases: vec![String::from("list_filter")],
95 }
96 }
97}
98
99impl HigherOrderUDFImpl for ArrayFilter {
100 fn name(&self) -> &str {
101 "array_filter"
102 }
103
104 fn aliases(&self) -> &[String] {
105 &self.aliases
106 }
107
108 fn signature(&self) -> &HigherOrderSignature {
109 &self.signature
110 }
111
112 fn lambda_parameters(
113 &self,
114 _step: usize,
115 fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
116 ) -> Result<LambdaParametersProgress> {
117 single_list_lambda_parameters(self.name(), fields)
118 }
119
120 fn return_field_from_args(
121 &self,
122 args: HigherOrderReturnFieldArgs,
123 ) -> Result<Arc<Field>> {
124 let (list, _lambda) = value_lambda_pair(self.name(), args.arg_fields)?;
125 Ok(Arc::new(Field::new(
126 "",
127 list.data_type().clone(),
128 list.is_nullable(),
129 )))
130 }
131
132 fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
133 let (list, lambda) = value_lambda_pair(self.name(), &args.args)?;
134 let list_array = list.to_array(args.number_rows)?;
135
136 let list_values = match extract_list_values(&list_array, args.return_type())? {
137 ListValuesResult::EarlyReturn(v) => return Ok(v),
138 ListValuesResult::Values(v) => v,
139 };
140
141 let field = match args.return_field.data_type() {
142 DataType::List(field) | DataType::LargeList(field) => Arc::clone(field),
143 _ => {
144 return exec_err!(
145 "{} expected return_field to be a list, got {}",
146 self.name(),
147 args.return_field
148 );
149 }
150 };
151
152 let values_param = || Ok(Arc::clone(&list_values));
153 let predicate_output = lambda.evaluate(&[&values_param], |arrays| {
154 let indices = list_values_row_number(&list_array)?;
155 Ok(take_arrays(arrays, &indices, None)?)
156 })?;
157
158 if let ColumnarValue::Scalar(ScalarValue::Boolean(b)) = &predicate_output {
160 return match b {
161 Some(true) => Ok(ColumnarValue::Array(list_array)),
162 _ => Ok(ColumnarValue::Array(empty_filtered_list(
163 &list_array,
164 field,
165 )?)),
166 };
167 }
168
169 let predicate = predicate_output.into_array(list_values.len())?;
170 let Some(predicate) = predicate.as_any().downcast_ref::<BooleanArray>() else {
171 return exec_err!(
172 "{} lambda must return boolean, got {}",
173 self.name(),
174 predicate.data_type()
175 );
176 };
177
178 let filtered_list = match list_array.data_type() {
180 DataType::List(_) => {
181 let list = list_array.as_list::<i32>();
182 let adjusted_offsets = adjust_offsets_for_slice(list);
183 let (filtered_values, new_offsets) =
184 filter_list_values(&list_values, predicate, &adjusted_offsets)?;
185 Arc::new(ListArray::new(
186 field,
187 new_offsets,
188 filtered_values,
189 list.nulls().cloned(),
190 )) as ArrayRef
191 }
192 DataType::LargeList(_) => {
193 let large_list = list_array.as_list::<i64>();
194 let adjusted_offsets = adjust_offsets_for_slice(large_list);
195 let (filtered_values, new_offsets) =
196 filter_list_values(&list_values, predicate, &adjusted_offsets)?;
197 Arc::new(LargeListArray::new(
198 field,
199 new_offsets,
200 filtered_values,
201 large_list.nulls().cloned(),
202 ))
203 }
204 other => exec_err!("expected list, got {other}")?,
205 };
206
207 Ok(ColumnarValue::Array(filtered_list))
208 }
209
210 fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
211 coerce_single_list_arg(self.name(), arg_types)
212 }
213
214 fn documentation(&self) -> Option<&Documentation> {
215 self.doc()
216 }
217}
218
219fn empty_filtered_list(list_array: &ArrayRef, field: FieldRef) -> Result<ArrayRef> {
222 let n = list_array.len();
223 let empty_values = new_empty_array(field.data_type());
224 Ok(match list_array.data_type() {
225 DataType::List(_) => {
226 let list = list_array.as_list::<i32>();
227 Arc::new(ListArray::new(
228 field,
229 OffsetBuffer::new(ScalarBuffer::from(vec![0i32; n + 1])),
230 empty_values,
231 list.nulls().cloned(),
232 ))
233 }
234 DataType::LargeList(_) => {
235 let list = list_array.as_list::<i64>();
236 Arc::new(LargeListArray::new(
237 field,
238 OffsetBuffer::new(ScalarBuffer::from(vec![0i64; n + 1])),
239 empty_values,
240 list.nulls().cloned(),
241 ))
242 }
243 other => return exec_err!("expected list, got {other}"),
244 })
245}
246
247fn filter_list_values<O: OffsetSizeTrait>(
250 values: &ArrayRef,
251 predicate: &BooleanArray,
252 offsets: &OffsetBuffer<O>,
253) -> Result<(ArrayRef, OffsetBuffer<O>)> {
254 let num_sublists = offsets.len().saturating_sub(1);
255 let mut builder = OffsetBufferBuilder::<O>::new(num_sublists);
256
257 let has_nulls = predicate.null_count() > 0;
258 for i in 0..num_sublists {
259 let start = offsets[i].as_usize();
260 let end = offsets[i + 1].as_usize();
261 let count = if has_nulls {
262 (start..end)
263 .filter(|&j| predicate.is_valid(j) && predicate.value(j))
264 .count()
265 } else {
266 predicate
267 .values()
268 .slice(start, end - start)
269 .count_set_bits()
270 };
271 builder.push_length(count);
272 }
273
274 let new_offsets = builder.finish();
275
276 if new_offsets.last() == offsets.last() {
277 return Ok((Arc::clone(values), offsets.clone()));
278 }
279
280 let filtered_values = arrow_filter(values.as_ref(), predicate)?;
282 Ok((filtered_values, new_offsets))
283}
284
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, AsArray},
        buffer::{NullBuffer, OffsetBuffer},
    };

    use crate::array_filter::array_filter_higher_order_function;
    use crate::lambda_utils::test_utils::{create_i32_list, eval_hof_on_i32_list, v};
    use datafusion_expr::lit;

    /// `i32` offsets for sublists of the given lengths.
    fn offsets(lengths: &[usize]) -> OffsetBuffer<i32> {
        OffsetBuffer::<i32>::from_lengths(lengths.iter().copied())
    }

    /// Applies `array_filter(list, x -> x > 2)` to an `Int32` list.
    fn keep_greater_than_two(
        list: impl Array + Clone + 'static,
    ) -> datafusion_common::Result<arrow::array::ArrayRef> {
        let greater_than_two = v().gt(lit(2i32));
        eval_hof_on_i32_list(array_filter_higher_order_function(), list, greater_than_two)
    }

    #[test]
    fn filter_basic() {
        let input = create_i32_list(vec![1, 2, 3, 4, 5], offsets(&[5]), None);
        let want = create_i32_list(vec![3, 4, 5], offsets(&[3]), None);

        let output = keep_greater_than_two(input).unwrap();
        assert_eq!(output.as_list::<i32>(), &want);
    }

    #[test]
    fn filter_multiple_sublists() {
        // [[1, 5], [2, 4, 3]] -> [[5], [4, 3]]
        let input = create_i32_list(vec![1, 5, 2, 4, 3], offsets(&[2, 3]), None);
        let want = create_i32_list(vec![5, 4, 3], offsets(&[1, 2]), None);

        let output = keep_greater_than_two(input).unwrap();
        assert_eq!(output.as_list::<i32>(), &want);
    }

    #[test]
    fn filter_on_sliced_list_should_not_evaluate_on_unreachable_values() {
        // [[0], [1, 5, 2], [4, 3, 7]] sliced down to its last two rows; values
        // outside the slice must not affect the result.
        let input =
            create_i32_list(vec![0, 1, 5, 2, 4, 3, 7], offsets(&[1, 3, 3]), None).slice(1, 2);
        let want = create_i32_list(vec![5, 4, 3, 7], offsets(&[1, 3]), None);

        let output = keep_greater_than_two(input).unwrap();
        assert_eq!(output.as_list::<i32>(), &want);
    }

    #[test]
    fn filter_should_not_be_evaluated_on_values_underlying_null() {
        // Row 1 is null: the 99 and 100 stored beneath it must not show up in the
        // output, and that row stays null with an empty sublist.
        let row_validity = || Some(NullBuffer::from(vec![true, false, true]));
        let input = create_i32_list(
            vec![1, 5, 99, 100, 3, 7],
            offsets(&[2, 2, 2]),
            row_validity(),
        );
        let want = create_i32_list(vec![5, 3, 7], offsets(&[1, 0, 2]), row_validity());

        let output = keep_greater_than_two(input).unwrap();
        let got = output.as_list::<i32>();
        assert_eq!(got.data_type(), want.data_type());
        assert_eq!(got, &want);
    }

    #[test]
    fn filter_all_filtered_out() {
        let input = create_i32_list(vec![1, 2], offsets(&[2]), None);
        let want = create_i32_list(Vec::<i32>::new(), offsets(&[0]), None);

        let output = keep_greater_than_two(input).unwrap();
        assert_eq!(output.as_list::<i32>(), &want);
    }

    #[test]
    fn filter_nothing_filtered_reuses_values() {
        let input = create_i32_list(vec![3, 4, 5], offsets(&[3]), None);

        let output = keep_greater_than_two(input.clone()).unwrap();
        assert_eq!(output.as_list::<i32>(), &input);
    }

    #[test]
    fn scalar_true_predicate_returns_original_list() {
        let input = create_i32_list(vec![1, 2, 3], offsets(&[3]), None);

        let output =
            eval_hof_on_i32_list(array_filter_higher_order_function(), input.clone(), lit(true))
                .unwrap();
        assert_eq!(output.as_list::<i32>(), &input);
    }

    #[test]
    fn scalar_false_predicate_returns_empty_sublists() {
        let input = create_i32_list(vec![1, 2, 3, 4], offsets(&[2, 2]), None);
        let want = create_i32_list(Vec::<i32>::new(), offsets(&[0, 0]), None);

        let output =
            eval_hof_on_i32_list(array_filter_higher_order_function(), input, lit(false))
                .unwrap();
        assert_eq!(output.as_list::<i32>(), &want);
    }
}