datafusion_functions_nested/
remove.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
19
20use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23    Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, cast::AsArray,
24    new_empty_array,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, FieldRef};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
31use datafusion_expr::{
32    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33    ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Invokes the crate-level `make_udf_expr_and_func!` macro for `ArrayRemove`, passing
// the function name `array_remove`, its argument names (`array`, `element`), a
// description string, and the identifier `array_remove_udf`.
// NOTE(review): the macro is defined outside this file; presumably it generates the
// `array_remove` expression helper and the `array_remove_udf` accessor — confirm.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
46
// The `user_doc` attribute carries the user-facing documentation for this function
// (section, description, syntax, SQL example and argument descriptions).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4]                           |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF implementing `array_remove` (alias `list_remove`).
pub struct ArrayRemove {
    // Argument signature returned by `ScalarUDFImpl::signature`.
    signature: Signature,
    // Alternative names returned by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
73
74impl Default for ArrayRemove {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl ArrayRemove {
81    pub fn new() -> Self {
82        Self {
83            signature: Signature::array_and_element(Volatility::Immutable),
84            aliases: vec!["list_remove".to_string()],
85        }
86    }
87}
88
89impl ScalarUDFImpl for ArrayRemove {
90    fn as_any(&self) -> &dyn Any {
91        self
92    }
93
94    fn name(&self) -> &str {
95        "array_remove"
96    }
97
98    fn signature(&self) -> &Signature {
99        &self.signature
100    }
101
102    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103        internal_err!("return_field_from_args should be used instead")
104    }
105
106    fn return_field_from_args(
107        &self,
108        args: datafusion_expr::ReturnFieldArgs,
109    ) -> Result<FieldRef> {
110        Ok(Arc::clone(&args.arg_fields[0]))
111    }
112
113    fn invoke_with_args(
114        &self,
115        args: datafusion_expr::ScalarFunctionArgs,
116    ) -> Result<ColumnarValue> {
117        make_scalar_function(array_remove_inner)(&args.args)
118    }
119
120    fn aliases(&self) -> &[String] {
121        &self.aliases
122    }
123
124    fn documentation(&self) -> Option<&Documentation> {
125        self.doc()
126    }
127}
128
// Invokes the crate-level `make_udf_expr_and_func!` macro for `ArrayRemoveN`, passing
// the function name `array_remove_n`, its argument names (`array`, `element`, `max`),
// a description string, and the identifier `array_remove_n_udf`.
// NOTE(review): the macro is defined outside this file; presumably it generates the
// `array_remove_n` expression helper and the `array_remove_n_udf` accessor — confirm.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
136
137#[user_doc(
138    doc_section(label = "Array Functions"),
139    description = "Removes the first `max` elements from the array equal to the given value.",
140    syntax_example = "array_remove_n(array, element, max))",
141    sql_example = r#"```sql
142> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
143+---------------------------------------------------------+
144| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
145+---------------------------------------------------------+
146| [1, 3, 2, 1, 4]                                         |
147+---------------------------------------------------------+
148```"#,
149    argument(
150        name = "array",
151        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
152    ),
153    argument(
154        name = "element",
155        description = "Element to be removed from the array."
156    ),
157    argument(name = "max", description = "Number of first occurrences to remove.")
158)]
159#[derive(Debug, PartialEq, Eq, Hash)]
160pub(super) struct ArrayRemoveN {
161    signature: Signature,
162    aliases: Vec<String>,
163}
164
165impl ArrayRemoveN {
166    pub fn new() -> Self {
167        Self {
168            signature: Signature::new(
169                TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
170                    arguments: vec![
171                        ArrayFunctionArgument::Array,
172                        ArrayFunctionArgument::Element,
173                        ArrayFunctionArgument::Index,
174                    ],
175                    array_coercion: Some(ListCoercion::FixedSizedListToList),
176                }),
177                Volatility::Immutable,
178            ),
179            aliases: vec!["list_remove_n".to_string()],
180        }
181    }
182}
183
184impl ScalarUDFImpl for ArrayRemoveN {
185    fn as_any(&self) -> &dyn Any {
186        self
187    }
188
189    fn name(&self) -> &str {
190        "array_remove_n"
191    }
192
193    fn signature(&self) -> &Signature {
194        &self.signature
195    }
196
197    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
198        internal_err!("return_field_from_args should be used instead")
199    }
200
201    fn return_field_from_args(
202        &self,
203        args: datafusion_expr::ReturnFieldArgs,
204    ) -> Result<FieldRef> {
205        Ok(Arc::clone(&args.arg_fields[0]))
206    }
207
208    fn invoke_with_args(
209        &self,
210        args: datafusion_expr::ScalarFunctionArgs,
211    ) -> Result<ColumnarValue> {
212        make_scalar_function(array_remove_n_inner)(&args.args)
213    }
214
215    fn aliases(&self) -> &[String] {
216        &self.aliases
217    }
218
219    fn documentation(&self) -> Option<&Documentation> {
220        self.doc()
221    }
222}
223
// Invokes the crate-level `make_udf_expr_and_func!` macro for `ArrayRemoveAll`,
// passing the function name `array_remove_all`, its argument names (`array`,
// `element`), a description string, and the identifier `array_remove_all_udf`.
// NOTE(review): the macro is defined outside this file; presumably it generates the
// `array_remove_all` expression helper and the `array_remove_all_udf` accessor — confirm.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
231
// The `user_doc` attribute carries the user-facing documentation for this function
// (section, description, syntax, SQL example and argument descriptions).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4]                                     |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
/// Scalar UDF implementing `array_remove_all` (alias `list_remove_all`).
pub(super) struct ArrayRemoveAll {
    // Argument signature returned by `ScalarUDFImpl::signature`.
    signature: Signature,
    // Alternative names returned by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
258
259impl ArrayRemoveAll {
260    pub fn new() -> Self {
261        Self {
262            signature: Signature::array_and_element(Volatility::Immutable),
263            aliases: vec!["list_remove_all".to_string()],
264        }
265    }
266}
267
268impl ScalarUDFImpl for ArrayRemoveAll {
269    fn as_any(&self) -> &dyn Any {
270        self
271    }
272
273    fn name(&self) -> &str {
274        "array_remove_all"
275    }
276
277    fn signature(&self) -> &Signature {
278        &self.signature
279    }
280
281    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
282        internal_err!("return_field_from_args should be used instead")
283    }
284
285    fn return_field_from_args(
286        &self,
287        args: datafusion_expr::ReturnFieldArgs,
288    ) -> Result<FieldRef> {
289        Ok(Arc::clone(&args.arg_fields[0]))
290    }
291
292    fn invoke_with_args(
293        &self,
294        args: datafusion_expr::ScalarFunctionArgs,
295    ) -> Result<ColumnarValue> {
296        make_scalar_function(array_remove_all_inner)(&args.args)
297    }
298
299    fn aliases(&self) -> &[String] {
300        &self.aliases
301    }
302
303    fn documentation(&self) -> Option<&Documentation> {
304        self.doc()
305    }
306}
307
308fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
309    let [array, element] = take_function_args("array_remove", args)?;
310
311    let arr_n = vec![1; array.len()];
312    array_remove_internal(array, element, &arr_n)
313}
314
315fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
316    let [array, element, max] = take_function_args("array_remove_n", args)?;
317
318    let arr_n = as_int64_array(max)?.values().to_vec();
319    array_remove_internal(array, element, &arr_n)
320}
321
322fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
323    let [array, element] = take_function_args("array_remove_all", args)?;
324
325    let arr_n = vec![i64::MAX; array.len()];
326    array_remove_internal(array, element, &arr_n)
327}
328
329fn array_remove_internal(
330    array: &ArrayRef,
331    element_array: &ArrayRef,
332    arr_n: &[i64],
333) -> Result<ArrayRef> {
334    match array.data_type() {
335        DataType::List(_) => {
336            let list_array = array.as_list::<i32>();
337            general_remove::<i32>(list_array, element_array, arr_n)
338        }
339        DataType::LargeList(_) => {
340            let list_array = array.as_list::<i64>();
341            general_remove::<i64>(list_array, element_array, arr_n)
342        }
343        array_type => {
344            exec_err!("array_remove_all does not support type '{array_type}'.")
345        }
346    }
347}
348
349/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
350/// of `element_array[i]`.
351///
352/// The type of each **element** in `list_array` must be the same as the type of
353/// `element_array`. This function also handles nested arrays
354/// ([`arrow::array::ListArray`] of [`arrow::array::ListArray`]s)
355///
356/// For example, when called to remove a list array (where each element is a
357/// list of int32s, the second argument are int32 arrays, and the
358/// third argument is the number of occurrences to remove
359///
360/// ```text
361/// general_remove(
362///   [1, 2, 3, 2], 2, 1    ==> [1, 3, 2]   (only the first 2 is removed)
363///   [4, 5, 6, 5], 5, 2    ==> [4, 6]  (both 5s are removed)
364/// )
365/// ```
366fn general_remove<OffsetSize: OffsetSizeTrait>(
367    list_array: &GenericListArray<OffsetSize>,
368    element_array: &ArrayRef,
369    arr_n: &[i64],
370) -> Result<ArrayRef> {
371    let list_field = match list_array.data_type() {
372        DataType::List(field) | DataType::LargeList(field) => field,
373        _ => {
374            return exec_err!(
375                "Expected List or LargeList data type, got {:?}",
376                list_array.data_type()
377            );
378        }
379    };
380    let data_type = list_field.data_type();
381    let mut new_values = vec![];
382    // Build up the offsets for the final output array
383    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
384    offsets.push(OffsetSize::zero());
385
386    // n is the number of elements to remove in this row
387    for (row_index, (list_array_row, n)) in
388        list_array.iter().zip(arr_n.iter()).enumerate()
389    {
390        match list_array_row {
391            Some(list_array_row) => {
392                let eq_array = utils::compare_element_to_list(
393                    &list_array_row,
394                    element_array,
395                    row_index,
396                    false,
397                )?;
398
399                // We need to keep at most first n elements as `false`, which represent the elements to remove.
400                let eq_array = if eq_array.false_count() < *n as usize {
401                    eq_array
402                } else {
403                    let mut count = 0;
404                    eq_array
405                        .iter()
406                        .map(|e| {
407                            // Keep first n `false` elements, and reverse other elements to `true`.
408                            if let Some(false) = e {
409                                if count < *n {
410                                    count += 1;
411                                    e
412                                } else {
413                                    Some(true)
414                                }
415                            } else {
416                                e
417                            }
418                        })
419                        .collect::<BooleanArray>()
420                };
421
422                let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
423                offsets.push(
424                    offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
425                );
426                new_values.push(filtered_array);
427            }
428            None => {
429                // Null element results in a null row (no new offsets)
430                offsets.push(offsets[row_index]);
431            }
432        }
433    }
434
435    let values = if new_values.is_empty() {
436        new_empty_array(data_type)
437    } else {
438        let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
439        arrow::compute::concat(&new_values)?
440    };
441
442    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
443        Arc::clone(list_field),
444        OffsetBuffer::new(offsets.into()),
445        values,
446        list_array.nulls().cloned(),
447    )?))
448}
449
450#[cfg(test)]
451mod tests {
452    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
453    use arrow::array::{
454        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
455    };
456    use arrow::datatypes::{DataType, Field, Int32Type};
457    use datafusion_common::ScalarValue;
458    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
459    use datafusion_expr_common::columnar_value::ColumnarValue;
460    use std::ops::Deref;
461    use std::sync::Arc;
462
463    #[test]
464    fn test_array_remove_nullability() {
465        for nullability in [true, false] {
466            for item_nullability in [true, false] {
467                let input_field = Arc::new(Field::new(
468                    "num",
469                    DataType::new_list(DataType::Int32, item_nullability),
470                    nullability,
471                ));
472                let args_fields = vec![
473                    Arc::clone(&input_field),
474                    Arc::new(Field::new("a", DataType::Int32, false)),
475                ];
476                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];
477
478                let result = ArrayRemove::new()
479                    .return_field_from_args(ReturnFieldArgs {
480                        arg_fields: &args_fields,
481                        scalar_arguments: &scalar_args,
482                    })
483                    .unwrap();
484
485                assert_eq!(result, input_field);
486            }
487        }
488    }
489
490    #[test]
491    fn test_array_remove_n_nullability() {
492        for nullability in [true, false] {
493            for item_nullability in [true, false] {
494                let input_field = Arc::new(Field::new(
495                    "num",
496                    DataType::new_list(DataType::Int32, item_nullability),
497                    nullability,
498                ));
499                let args_fields = vec![
500                    Arc::clone(&input_field),
501                    Arc::new(Field::new("a", DataType::Int32, false)),
502                    Arc::new(Field::new("b", DataType::Int64, false)),
503                ];
504                let scalar_args = vec![
505                    None,
506                    Some(&ScalarValue::Int32(Some(1))),
507                    Some(&ScalarValue::Int64(Some(1))),
508                ];
509
510                let result = ArrayRemoveN::new()
511                    .return_field_from_args(ReturnFieldArgs {
512                        arg_fields: &args_fields,
513                        scalar_arguments: &scalar_args,
514                    })
515                    .unwrap();
516
517                assert_eq!(result, input_field);
518            }
519        }
520    }
521
522    #[test]
523    fn test_array_remove_all_nullability() {
524        for nullability in [true, false] {
525            for item_nullability in [true, false] {
526                let input_field = Arc::new(Field::new(
527                    "num",
528                    DataType::new_list(DataType::Int32, item_nullability),
529                    nullability,
530                ));
531                let result = ArrayRemoveAll::new()
532                    .return_field_from_args(ReturnFieldArgs {
533                        arg_fields: &[Arc::clone(&input_field)],
534                        scalar_arguments: &[None],
535                    })
536                    .unwrap();
537
538                assert_eq!(result, input_field);
539            }
540        }
541    }
542
543    fn ensure_field_nullability<O: OffsetSizeTrait>(
544        field_nullable: bool,
545        list: GenericListArray<O>,
546    ) -> GenericListArray<O> {
547        let (field, offsets, values, nulls) = list.into_parts();
548
549        if field.is_nullable() == field_nullable {
550            return GenericListArray::new(field, offsets, values, nulls);
551        }
552        if !field_nullable {
553            assert_eq!(nulls, None);
554        }
555
556        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));
557
558        GenericListArray::new(field, offsets, values, nulls)
559    }
560
561    #[test]
562    fn test_array_remove_non_nullable() {
563        let input_list = Arc::new(ensure_field_nullability(
564            false,
565            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
566                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
567                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
568            ]),
569        ));
570        let expected_list = ensure_field_nullability(
571            false,
572            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
573                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
574                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
575            ]),
576        );
577
578        let element_to_remove = ScalarValue::Int32(Some(2));
579
580        assert_array_remove(input_list, expected_list, element_to_remove);
581    }
582
583    #[test]
584    fn test_array_remove_nullable() {
585        let input_list = Arc::new(ensure_field_nullability(
586            true,
587            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
588                Some(vec![
589                    Some(1),
590                    Some(2),
591                    Some(2),
592                    Some(3),
593                    None,
594                    Some(1),
595                    Some(4),
596                ]),
597                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
598            ]),
599        ));
600        let expected_list = ensure_field_nullability(
601            true,
602            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
603                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
604                Some(vec![Some(42), None, Some(63), Some(2)]),
605            ]),
606        );
607
608        let element_to_remove = ScalarValue::Int32(Some(2));
609
610        assert_array_remove(input_list, expected_list, element_to_remove);
611    }
612
613    fn assert_array_remove(
614        input_list: ArrayRef,
615        expected_list: GenericListArray<i32>,
616        element_to_remove: ScalarValue,
617    ) {
618        assert_eq!(input_list.data_type(), expected_list.data_type());
619        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
620        let input_list_len = input_list.len();
621        let input_list_data_type = input_list.data_type().clone();
622
623        let udf = ArrayRemove::new();
624        let args_fields = vec![
625            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
626            Arc::new(Field::new(
627                "el",
628                element_to_remove.data_type(),
629                element_to_remove.is_null(),
630            )),
631        ];
632        let scalar_args = vec![None, Some(&element_to_remove)];
633
634        let return_field = udf
635            .return_field_from_args(ReturnFieldArgs {
636                arg_fields: &args_fields,
637                scalar_arguments: &scalar_args,
638            })
639            .unwrap();
640
641        let result = udf
642            .invoke_with_args(ScalarFunctionArgs {
643                args: vec![
644                    ColumnarValue::Array(input_list),
645                    ColumnarValue::Scalar(element_to_remove),
646                ],
647                arg_fields: args_fields,
648                number_rows: input_list_len,
649                return_field,
650                config_options: Arc::new(Default::default()),
651            })
652            .unwrap();
653
654        assert_eq!(result.data_type(), input_list_data_type);
655        match result {
656            ColumnarValue::Array(array) => {
657                let result_list = array.as_list::<i32>();
658                assert_eq!(result_list, &expected_list);
659            }
660            _ => panic!("Expected ColumnarValue::Array"),
661        }
662    }
663
664    #[test]
665    fn test_array_remove_n_non_nullable() {
666        let input_list = Arc::new(ensure_field_nullability(
667            false,
668            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
669                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
670                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
671            ]),
672        ));
673        let expected_list = ensure_field_nullability(
674            false,
675            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
676                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
677                Some(([42, 55, 63]).iter().copied().map(Some)),
678            ]),
679        );
680
681        let element_to_remove = ScalarValue::Int32(Some(2));
682
683        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
684    }
685
686    #[test]
687    fn test_array_remove_n_nullable() {
688        let input_list = Arc::new(ensure_field_nullability(
689            true,
690            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
691                Some(vec![
692                    Some(1),
693                    Some(2),
694                    Some(2),
695                    Some(3),
696                    None,
697                    Some(1),
698                    Some(4),
699                ]),
700                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
701            ]),
702        ));
703        let expected_list = ensure_field_nullability(
704            true,
705            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
706                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
707                Some(vec![Some(42), None, Some(63)]),
708            ]),
709        );
710
711        let element_to_remove = ScalarValue::Int32(Some(2));
712
713        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
714    }
715
716    fn assert_array_remove_n(
717        input_list: ArrayRef,
718        expected_list: GenericListArray<i32>,
719        element_to_remove: ScalarValue,
720        n: i64,
721    ) {
722        assert_eq!(input_list.data_type(), expected_list.data_type());
723        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
724        let input_list_len = input_list.len();
725        let input_list_data_type = input_list.data_type().clone();
726
727        let count_scalar = ScalarValue::Int64(Some(n));
728
729        let udf = ArrayRemoveN::new();
730        let args_fields = vec![
731            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
732            Arc::new(Field::new(
733                "el",
734                element_to_remove.data_type(),
735                element_to_remove.is_null(),
736            )),
737            Arc::new(Field::new("count", DataType::Int64, false)),
738        ];
739        let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];
740
741        let return_field = udf
742            .return_field_from_args(ReturnFieldArgs {
743                arg_fields: &args_fields,
744                scalar_arguments: &scalar_args,
745            })
746            .unwrap();
747
748        let result = udf
749            .invoke_with_args(ScalarFunctionArgs {
750                args: vec![
751                    ColumnarValue::Array(input_list),
752                    ColumnarValue::Scalar(element_to_remove),
753                    ColumnarValue::Scalar(count_scalar),
754                ],
755                arg_fields: args_fields,
756                number_rows: input_list_len,
757                return_field,
758                config_options: Arc::new(Default::default()),
759            })
760            .unwrap();
761
762        assert_eq!(result.data_type(), input_list_data_type);
763        match result {
764            ColumnarValue::Array(array) => {
765                let result_list = array.as_list::<i32>();
766                assert_eq!(result_list, &expected_list);
767            }
768            _ => panic!("Expected ColumnarValue::Array"),
769        }
770    }
771
772    #[test]
773    fn test_array_remove_all_non_nullable() {
774        let input_list = Arc::new(ensure_field_nullability(
775            false,
776            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
777                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
778                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
779            ]),
780        ));
781        let expected_list = ensure_field_nullability(
782            false,
783            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
784                Some(([1, 3, 1, 4]).iter().copied().map(Some)),
785                Some(([42, 55, 63]).iter().copied().map(Some)),
786            ]),
787        );
788
789        let element_to_remove = ScalarValue::Int32(Some(2));
790
791        assert_array_remove_all(input_list, expected_list, element_to_remove);
792    }
793
794    #[test]
795    fn test_array_remove_all_nullable() {
796        let input_list = Arc::new(ensure_field_nullability(
797            true,
798            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
799                Some(vec![
800                    Some(1),
801                    Some(2),
802                    Some(2),
803                    Some(3),
804                    None,
805                    Some(1),
806                    Some(4),
807                ]),
808                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
809            ]),
810        ));
811        let expected_list = ensure_field_nullability(
812            true,
813            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
814                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
815                Some(vec![Some(42), None, Some(63)]),
816            ]),
817        );
818
819        let element_to_remove = ScalarValue::Int32(Some(2));
820
821        assert_array_remove_all(input_list, expected_list, element_to_remove);
822    }
823
824    fn assert_array_remove_all(
825        input_list: ArrayRef,
826        expected_list: GenericListArray<i32>,
827        element_to_remove: ScalarValue,
828    ) {
829        assert_eq!(input_list.data_type(), expected_list.data_type());
830        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
831        let input_list_len = input_list.len();
832        let input_list_data_type = input_list.data_type().clone();
833
834        let udf = ArrayRemoveAll::new();
835        let args_fields = vec![
836            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
837            Arc::new(Field::new(
838                "el",
839                element_to_remove.data_type(),
840                element_to_remove.is_null(),
841            )),
842        ];
843        let scalar_args = vec![None, Some(&element_to_remove)];
844
845        let return_field = udf
846            .return_field_from_args(ReturnFieldArgs {
847                arg_fields: &args_fields,
848                scalar_arguments: &scalar_args,
849            })
850            .unwrap();
851
852        let result = udf
853            .invoke_with_args(ScalarFunctionArgs {
854                args: vec![
855                    ColumnarValue::Array(input_list),
856                    ColumnarValue::Scalar(element_to_remove),
857                ],
858                arg_fields: args_fields,
859                number_rows: input_list_len,
860                return_field,
861                config_options: Arc::new(Default::default()),
862            })
863            .unwrap();
864
865        assert_eq!(result.data_type(), input_list_data_type);
866        match result {
867            ColumnarValue::Array(array) => {
868                let result_list = array.as_list::<i32>();
869                assert_eq!(result_list, &expected_list);
870            }
871            _ => panic!("Expected ColumnarValue::Array"),
872        }
873    }
874}