Skip to main content

datafusion_functions_nested/
remove.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
19
20use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23    Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder,
24    OffsetSizeTrait, cast::AsArray, make_array,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, FieldRef};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
31use datafusion_expr::{
32    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33    ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Expression/UDF boilerplate for `ArrayRemove`, generated by
// `make_udf_expr_and_func!` (defined elsewhere in this crate).
// NOTE(review): judging by its arguments, this yields an
// `array_remove(array, element)` expression helper and an `array_remove_udf`
// accessor — confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
46
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4]                           |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
/// Scalar UDF implementation of `array_remove` (alias `list_remove`):
/// removes the first occurrence of `element` from each list.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
    /// Accepted argument types; built in `ArrayRemove::new`.
    signature: Signature,
    /// Alternative SQL names under which this function is also registered.
    aliases: Vec<String>,
}
73
74impl Default for ArrayRemove {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl ArrayRemove {
81    pub fn new() -> Self {
82        Self {
83            signature: Signature::array_and_element(Volatility::Immutable),
84            aliases: vec!["list_remove".to_string()],
85        }
86    }
87}
88
89impl ScalarUDFImpl for ArrayRemove {
90    fn as_any(&self) -> &dyn Any {
91        self
92    }
93
94    fn name(&self) -> &str {
95        "array_remove"
96    }
97
98    fn signature(&self) -> &Signature {
99        &self.signature
100    }
101
102    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103        internal_err!("return_field_from_args should be used instead")
104    }
105
106    fn return_field_from_args(
107        &self,
108        args: datafusion_expr::ReturnFieldArgs,
109    ) -> Result<FieldRef> {
110        Ok(Arc::clone(&args.arg_fields[0]))
111    }
112
113    fn invoke_with_args(
114        &self,
115        args: datafusion_expr::ScalarFunctionArgs,
116    ) -> Result<ColumnarValue> {
117        make_scalar_function(array_remove_inner)(&args.args)
118    }
119
120    fn aliases(&self) -> &[String] {
121        &self.aliases
122    }
123
124    fn documentation(&self) -> Option<&Documentation> {
125        self.doc()
126    }
127}
128
// Expression/UDF boilerplate for `ArrayRemoveN`, generated by
// `make_udf_expr_and_func!` (defined elsewhere in this crate).
// NOTE(review): judging by its arguments, this yields an
// `array_remove_n(array, element, max)` expression helper and an
// `array_remove_n_udf` accessor — confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
136
137#[user_doc(
138    doc_section(label = "Array Functions"),
139    description = "Removes the first `max` elements from the array equal to the given value.",
140    syntax_example = "array_remove_n(array, element, max))",
141    sql_example = r#"```sql
142> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
143+---------------------------------------------------------+
144| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
145+---------------------------------------------------------+
146| [1, 3, 2, 1, 4]                                         |
147+---------------------------------------------------------+
148```"#,
149    argument(
150        name = "array",
151        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
152    ),
153    argument(
154        name = "element",
155        description = "Element to be removed from the array."
156    ),
157    argument(name = "max", description = "Number of first occurrences to remove.")
158)]
159#[derive(Debug, PartialEq, Eq, Hash)]
160pub(super) struct ArrayRemoveN {
161    signature: Signature,
162    aliases: Vec<String>,
163}
164
165impl ArrayRemoveN {
166    pub fn new() -> Self {
167        Self {
168            signature: Signature::new(
169                TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
170                    arguments: vec![
171                        ArrayFunctionArgument::Array,
172                        ArrayFunctionArgument::Element,
173                        ArrayFunctionArgument::Index,
174                    ],
175                    array_coercion: Some(ListCoercion::FixedSizedListToList),
176                }),
177                Volatility::Immutable,
178            ),
179            aliases: vec!["list_remove_n".to_string()],
180        }
181    }
182}
183
184impl ScalarUDFImpl for ArrayRemoveN {
185    fn as_any(&self) -> &dyn Any {
186        self
187    }
188
189    fn name(&self) -> &str {
190        "array_remove_n"
191    }
192
193    fn signature(&self) -> &Signature {
194        &self.signature
195    }
196
197    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
198        internal_err!("return_field_from_args should be used instead")
199    }
200
201    fn return_field_from_args(
202        &self,
203        args: datafusion_expr::ReturnFieldArgs,
204    ) -> Result<FieldRef> {
205        Ok(Arc::clone(&args.arg_fields[0]))
206    }
207
208    fn invoke_with_args(
209        &self,
210        args: datafusion_expr::ScalarFunctionArgs,
211    ) -> Result<ColumnarValue> {
212        make_scalar_function(array_remove_n_inner)(&args.args)
213    }
214
215    fn aliases(&self) -> &[String] {
216        &self.aliases
217    }
218
219    fn documentation(&self) -> Option<&Documentation> {
220        self.doc()
221    }
222}
223
// Expression/UDF boilerplate for `ArrayRemoveAll`, generated by
// `make_udf_expr_and_func!` (defined elsewhere in this crate).
// NOTE(review): judging by its arguments, this yields an
// `array_remove_all(array, element)` expression helper and an
// `array_remove_all_udf` accessor — confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
231
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4]                                     |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
/// Scalar UDF implementation of `array_remove_all` (alias `list_remove_all`):
/// removes every occurrence of `element` from each list.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveAll {
    /// Accepted argument types; built in `ArrayRemoveAll::new`.
    signature: Signature,
    /// Alternative SQL names under which this function is also registered.
    aliases: Vec<String>,
}
258
259impl ArrayRemoveAll {
260    pub fn new() -> Self {
261        Self {
262            signature: Signature::array_and_element(Volatility::Immutable),
263            aliases: vec!["list_remove_all".to_string()],
264        }
265    }
266}
267
268impl ScalarUDFImpl for ArrayRemoveAll {
269    fn as_any(&self) -> &dyn Any {
270        self
271    }
272
273    fn name(&self) -> &str {
274        "array_remove_all"
275    }
276
277    fn signature(&self) -> &Signature {
278        &self.signature
279    }
280
281    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
282        internal_err!("return_field_from_args should be used instead")
283    }
284
285    fn return_field_from_args(
286        &self,
287        args: datafusion_expr::ReturnFieldArgs,
288    ) -> Result<FieldRef> {
289        Ok(Arc::clone(&args.arg_fields[0]))
290    }
291
292    fn invoke_with_args(
293        &self,
294        args: datafusion_expr::ScalarFunctionArgs,
295    ) -> Result<ColumnarValue> {
296        make_scalar_function(array_remove_all_inner)(&args.args)
297    }
298
299    fn aliases(&self) -> &[String] {
300        &self.aliases
301    }
302
303    fn documentation(&self) -> Option<&Documentation> {
304        self.doc()
305    }
306}
307
308fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
309    let [array, element] = take_function_args("array_remove", args)?;
310
311    let arr_n = vec![1; array.len()];
312    array_remove_internal(array, element, &arr_n)
313}
314
315fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
316    let [array, element, max] = take_function_args("array_remove_n", args)?;
317
318    let arr_n = as_int64_array(max)?.values().to_vec();
319    array_remove_internal(array, element, &arr_n)
320}
321
322fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
323    let [array, element] = take_function_args("array_remove_all", args)?;
324
325    let arr_n = vec![i64::MAX; array.len()];
326    array_remove_internal(array, element, &arr_n)
327}
328
329fn array_remove_internal(
330    array: &ArrayRef,
331    element_array: &ArrayRef,
332    arr_n: &[i64],
333) -> Result<ArrayRef> {
334    match array.data_type() {
335        DataType::List(_) => {
336            let list_array = array.as_list::<i32>();
337            general_remove::<i32>(list_array, element_array, arr_n)
338        }
339        DataType::LargeList(_) => {
340            let list_array = array.as_list::<i64>();
341            general_remove::<i64>(list_array, element_array, arr_n)
342        }
343        array_type => {
344            exec_err!("array_remove_all does not support type '{array_type}'.")
345        }
346    }
347}
348
/// For each row `i`, removes up to `arr_n[i]` occurrences of
/// `element_array[i]` from `list_array[i]`, scanning from the front.
///
/// The type of each **element** in `list_array` must be the same as the type of
/// `element_array`. This function also handles nested arrays
/// ([`arrow::array::ListArray`] of [`arrow::array::ListArray`]s).
///
/// Row semantics:
/// - if `list_array[i]` or `element_array[i]` is NULL, output row `i` is NULL;
/// - list entries whose comparison result is NULL are always kept;
/// - `arr_n[i] <= 0` removes nothing.
///
/// For example, when `list_array` holds lists of int32s, `element_array` holds
/// int32s, and `arr_n` holds the number of occurrences to remove:
///
/// ```text
/// general_remove(
///   [1, 2, 3, 2], 2, 1    ==> [1, 3, 2]   (only the first 2 is removed)
///   [4, 5, 6, 5], 5, 2    ==> [4, 6]  (both 5s are removed)
/// )
/// ```
///
/// # Errors
/// Returns an execution error if `list_array` is not a `List`/`LargeList`, or
/// if the per-row element comparison fails.
///
/// # Panics
/// Panics if `arr_n` has fewer entries than `list_array` has rows.
fn general_remove<OffsetSize: OffsetSizeTrait>(
    list_array: &GenericListArray<OffsetSize>,
    element_array: &ArrayRef,
    arr_n: &[i64],
) -> Result<ArrayRef> {
    let list_field = match list_array.data_type() {
        DataType::List(field) | DataType::LargeList(field) => field,
        _ => {
            return exec_err!(
                "Expected List or LargeList data type, got {:?}",
                list_array.data_type()
            );
        }
    };
    // The list's offsets index directly into this child data, so the
    // `start`/`end` values below can be passed straight to `mutable.extend`.
    let original_data = list_array.values().to_data();
    // Build up the offsets for the final output array
    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
    offsets.push(OffsetSize::zero());

    // Output child values are copied range-by-range from `original_data`;
    // its full length is an upper bound on what can be retained.
    let mut mutable = MutableArrayData::with_capacities(
        vec![&original_data],
        false,
        Capacities::Array(original_data.len()),
    );
    // Validity of each output row.
    let mut valid = NullBufferBuilder::new(list_array.len());

    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
        // NULL list or NULL element => NULL, zero-length output row.
        if list_array.is_null(row_index) || element_array.is_null(row_index) {
            offsets.push(offsets[row_index]);
            valid.append_null();
            continue;
        }

        let start = offset_window[0].to_usize().unwrap();
        let end = offset_window[1].to_usize().unwrap();
        // n is the number of elements to remove in this row
        let n = arr_n[row_index];

        // compare each element in the list, `false` means the element matches and should be removed
        let eq_array = utils::compare_element_to_list(
            &list_array.value(row_index),
            element_array,
            row_index,
            false,
        )?;

        // Number of matching entries; NULL comparison results are not counted.
        let num_to_remove = eq_array.false_count();

        // Fast path: no elements to remove, copy entire row
        if num_to_remove == 0 {
            mutable.extend(0, start, end);
            offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
            valid.append_non_null();
            continue;
        }

        // Remove at most `n` matching elements
        let max_removals = n.min(num_to_remove as i64);
        let mut removed = 0i64;
        let mut copied = 0usize;
        // marks the beginning of a range of elements pending to be copied.
        // Retained elements are copied in contiguous runs rather than one by one.
        let mut pending_batch_to_retain: Option<usize> = None;
        for (i, keep) in eq_array.iter().enumerate() {
            // Once `max_removals` is reached, later matches fall through to
            // the `else` branch and are retained.
            if keep == Some(false) && removed < max_removals {
                // Flush pending batch before skipping this element
                if let Some(bs) = pending_batch_to_retain {
                    mutable.extend(0, start + bs, start + i);
                    copied += i - bs;
                    pending_batch_to_retain = None;
                }
                removed += 1;
            } else if pending_batch_to_retain.is_none() {
                pending_batch_to_retain = Some(i);
            }
        }

        // Flush remaining batch
        if let Some(bs) = pending_batch_to_retain {
            mutable.extend(0, start + bs, start + eq_array.len());
            copied += eq_array.len() - bs;
        }

        offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
        valid.append_non_null();
    }

    let new_values = make_array(mutable.freeze());
    // The output keeps the input's item field.
    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
        Arc::clone(list_field),
        OffsetBuffer::new(offsets.into()),
        new_values,
        valid.finish(),
    )?))
}
460
// Unit tests: return-field derivation and end-to-end evaluation of
// `array_remove`, `array_remove_n` and `array_remove_all`.
#[cfg(test)]
mod tests {
    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
    use arrow::array::{
        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
    };
    use arrow::datatypes::{DataType, Field, Int32Type};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::ops::Deref;
    use std::sync::Arc;

    // With a non-nullable element argument, the return field must equal the
    // input list field for every combination of list/item nullability.
    #[test]
    fn test_array_remove_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                ];
                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];

                let result = ArrayRemove::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    // Same check for `array_remove_n`, with non-nullable element and count.
    #[test]
    fn test_array_remove_n_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                    Arc::new(Field::new("b", DataType::Int64, false)),
                ];
                let scalar_args = vec![
                    None,
                    Some(&ScalarValue::Int32(Some(1))),
                    Some(&ScalarValue::Int64(Some(1))),
                ];

                let result = ArrayRemoveN::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    // Same check for `array_remove_all`; note that only the list field is
    // passed here, not the element field.
    #[test]
    fn test_array_remove_all_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let result = ArrayRemoveAll::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &[Arc::clone(&input_field)],
                        scalar_arguments: &[None],
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    /// Rebuilds `list` so that its item field's nullability equals
    /// `field_nullable`, keeping offsets, values and nulls unchanged.
    ///
    /// NOTE(review): when making the field non-nullable, the assertion checks
    /// the list's own null buffer (as returned by `into_parts`), not the child
    /// values' nulls.
    fn ensure_field_nullability<O: OffsetSizeTrait>(
        field_nullable: bool,
        list: GenericListArray<O>,
    ) -> GenericListArray<O> {
        let (field, offsets, values, nulls) = list.into_parts();

        if field.is_nullable() == field_nullable {
            return GenericListArray::new(field, offsets, values, nulls);
        }
        if !field_nullable {
            assert_eq!(nulls, None);
        }

        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));

        GenericListArray::new(field, offsets, values, nulls)
    }

    // `array_remove` drops only the first matching element per row.
    #[test]
    fn test_array_remove_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove(input_list, expected_list, element_to_remove);
    }

    // NULL items inside a list are preserved in place.
    #[test]
    fn test_array_remove_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63), Some(2)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove(input_list, expected_list, element_to_remove);
    }

    /// Runs `ArrayRemove` on `input_list` with a scalar `element_to_remove`
    /// and asserts the result equals `expected_list` and keeps the input type.
    fn assert_array_remove(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let udf = ArrayRemove::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let scalar_args = vec![None, Some(&element_to_remove)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    // `array_remove_n` with n = 2 drops the first two matches per row.
    #[test]
    fn test_array_remove_n_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
    }

    #[test]
    fn test_array_remove_n_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
    }

    /// Runs `ArrayRemoveN` with scalar `element_to_remove` and scalar count
    /// `n`, asserting the result equals `expected_list`.
    fn assert_array_remove_n(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
        n: i64,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let count_scalar = ScalarValue::Int64(Some(n));

        let udf = ArrayRemoveN::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
            Arc::new(Field::new("count", DataType::Int64, false)),
        ];
        let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                    ColumnarValue::Scalar(count_scalar),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    // `array_remove_all` drops every match in each row.
    #[test]
    fn test_array_remove_all_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 3, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_all(input_list, expected_list, element_to_remove);
    }

    #[test]
    fn test_array_remove_all_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_all(input_list, expected_list, element_to_remove);
    }

    /// Runs `ArrayRemoveAll` on `input_list` with a scalar
    /// `element_to_remove` and asserts the result equals `expected_list`.
    fn assert_array_remove_all(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let udf = ArrayRemoveAll::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let scalar_args = vec![None, Some(&element_to_remove)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }
}