Skip to main content

datafusion_functions_nested/
remove.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
19
20use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23    Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
24    cast::AsArray, make_array,
25};
26use arrow::buffer::{NullBuffer, OffsetBuffer};
27use arrow::datatypes::{DataType, FieldRef};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
31use datafusion_expr::{
32    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::sync::Arc;
37
// Invokes the crate-level `make_udf_expr_and_func!` macro for `ArrayRemove`.
// The macro is defined outside this file; judging by its name and arguments it
// presumably emits an `array_remove(array, element)` expression helper and an
// `array_remove_udf` accessor — confirm against the macro definition.
38make_udf_expr_and_func!(
39    ArrayRemove,
40    array_remove,
41    array element,
42    "removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
43    array_remove_udf
44);
45
// User-facing documentation for `array_remove`. The `user_doc` macro is defined
// outside this file; it presumably provides the `doc()` method that
// `ScalarUDFImpl::documentation` returns — confirm against the macro.
// NOTE(review): in the second SQL example the table borders appear wider than
// the header/value rows, so the rendered table looks misaligned.
46#[user_doc(
47    doc_section(label = "Array Functions"),
48    description = "Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
49    syntax_example = "array_remove(array, element)",
50    sql_example = r#"```sql
51> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
52+----------------------------------------------+
53| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
54+----------------------------------------------+
55| [1, 2, 3, 2, 1, 4]                           |
56+----------------------------------------------+
57
58> select array_remove([1, 2, NULL, 2, 4], 2);
59+---------------------------------------------------+
60| array_remove(List([1,2,NULL,2,4]),Int64(2)) |
61+---------------------------------------------------+
62| [1, NULL, 2, 4]                              |
63+---------------------------------------------------+
64```"#,
65    argument(
66        name = "array",
67        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
68    ),
69    argument(
70        name = "element",
71        description = "Element to be removed from the array."
72    )
73)]
74#[derive(Debug, PartialEq, Eq, Hash)]
75pub struct ArrayRemove {
    // Returned by `ScalarUDFImpl::signature`; built in `ArrayRemove::new`.
76    signature: Signature,
    // Returned by `ScalarUDFImpl::aliases`; built in `ArrayRemove::new`.
77    aliases: Vec<String>,
78}
79
80impl Default for ArrayRemove {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl ArrayRemove {
87    pub fn new() -> Self {
88        Self {
89            signature: Signature::array_and_element(Volatility::Immutable),
90            aliases: vec!["list_remove".to_string()],
91        }
92    }
93}
94
95impl ScalarUDFImpl for ArrayRemove {
96    fn name(&self) -> &str {
97        "array_remove"
98    }
99
100    fn signature(&self) -> &Signature {
101        &self.signature
102    }
103
104    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
105        internal_err!("return_field_from_args should be used instead")
106    }
107
108    fn return_field_from_args(
109        &self,
110        args: datafusion_expr::ReturnFieldArgs,
111    ) -> Result<FieldRef> {
112        Ok(Arc::clone(&args.arg_fields[0]))
113    }
114
115    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
116        make_scalar_function(array_remove_inner)(&args.args)
117    }
118
119    fn aliases(&self) -> &[String] {
120        &self.aliases
121    }
122
123    fn documentation(&self) -> Option<&Documentation> {
124        self.doc()
125    }
126}
127
// Invokes the crate-level `make_udf_expr_and_func!` macro for `ArrayRemoveN`.
// The macro is defined outside this file; judging by its name and arguments it
// presumably emits an `array_remove_n(array, element, max)` expression helper
// and an `array_remove_n_udf` accessor — confirm against the macro definition.
128make_udf_expr_and_func!(
129    ArrayRemoveN,
130    array_remove_n,
131    array element max,
132    "removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
133    array_remove_n_udf
134);
135
// User-facing documentation for `array_remove_n`. The `user_doc` macro is
// defined outside this file; it presumably provides the `doc()` method that
// `ScalarUDFImpl::documentation` returns — confirm against the macro.
// NOTE(review): in the second SQL example the table borders appear wider than
// the header/value rows, so the rendered table looks misaligned.
136#[user_doc(
137    doc_section(label = "Array Functions"),
138    description = "Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
139    syntax_example = "array_remove_n(array, element, max)",
140    sql_example = r#"```sql
141> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
142+---------------------------------------------------------+
143| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
144+---------------------------------------------------------+
145| [1, 3, 2, 1, 4]                                         |
146+---------------------------------------------------------+
147
148> select array_remove_n([1, 2, NULL, 2, 4], 2, 2);
149+----------------------------------------------------------+
150| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) |
151+----------------------------------------------------------+
152| [1, NULL, 4]                                            |
153+----------------------------------------------------------+
154```"#,
155    argument(
156        name = "array",
157        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
158    ),
159    argument(
160        name = "element",
161        description = "Element to be removed from the array."
162    ),
163    argument(name = "max", description = "Number of first occurrences to remove.")
164)]
165#[derive(Debug, PartialEq, Eq, Hash)]
166pub struct ArrayRemoveN {
    // Returned by `ScalarUDFImpl::signature`; built in `ArrayRemoveN::new`.
167    signature: Signature,
    // Returned by `ScalarUDFImpl::aliases`; built in `ArrayRemoveN::new`.
168    aliases: Vec<String>,
169}
170
171impl Default for ArrayRemoveN {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl ArrayRemoveN {
178    pub fn new() -> Self {
179        Self {
180            signature: Signature::new(
181                TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
182                    arguments: vec![
183                        ArrayFunctionArgument::Array,
184                        ArrayFunctionArgument::Element,
185                        ArrayFunctionArgument::Index,
186                    ],
187                    array_coercion: Some(ListCoercion::FixedSizedListToList),
188                }),
189                Volatility::Immutable,
190            ),
191            aliases: vec!["list_remove_n".to_string()],
192        }
193    }
194}
195
196impl ScalarUDFImpl for ArrayRemoveN {
197    fn name(&self) -> &str {
198        "array_remove_n"
199    }
200
201    fn signature(&self) -> &Signature {
202        &self.signature
203    }
204
205    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
206        internal_err!("return_field_from_args should be used instead")
207    }
208
209    fn return_field_from_args(
210        &self,
211        args: datafusion_expr::ReturnFieldArgs,
212    ) -> Result<FieldRef> {
213        Ok(Arc::clone(&args.arg_fields[0]))
214    }
215
216    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
217        make_scalar_function(array_remove_n_inner)(&args.args)
218    }
219
220    fn aliases(&self) -> &[String] {
221        &self.aliases
222    }
223
224    fn documentation(&self) -> Option<&Documentation> {
225        self.doc()
226    }
227}
228
// Invokes the crate-level `make_udf_expr_and_func!` macro for `ArrayRemoveAll`.
// The macro is defined outside this file; judging by its name and arguments it
// presumably emits an `array_remove_all(array, element)` expression helper and
// an `array_remove_all_udf` accessor — confirm against the macro definition.
229make_udf_expr_and_func!(
230    ArrayRemoveAll,
231    array_remove_all,
232    array element,
233    "removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
234    array_remove_all_udf
235);
236
// User-facing documentation for `array_remove_all`. The `user_doc` macro is
// defined outside this file; it presumably provides the `doc()` method that
// `ScalarUDFImpl::documentation` returns — confirm against the macro.
// NOTE(review): in the second SQL example the table borders appear wider than
// the header/value rows, so the rendered table looks misaligned.
237#[user_doc(
238    doc_section(label = "Array Functions"),
239    description = "Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
240    syntax_example = "array_remove_all(array, element)",
241    sql_example = r#"```sql
242> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
243+--------------------------------------------------+
244| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
245+--------------------------------------------------+
246| [1, 3, 1, 4]                                     |
247+--------------------------------------------------+
248
249> select array_remove_all([1, 2, NULL, 2, 4], 2);
250+-----------------------------------------------------+
251| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) |
252+-----------------------------------------------------+
253| [1, NULL, 4]                                     |
254+-----------------------------------------------------+
255```"#,
256    argument(
257        name = "array",
258        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
259    ),
260    argument(
261        name = "element",
262        description = "Element to be removed from the array."
263    )
264)]
265#[derive(Debug, PartialEq, Eq, Hash)]
266pub struct ArrayRemoveAll {
    // Returned by `ScalarUDFImpl::signature`; built in `ArrayRemoveAll::new`.
267    signature: Signature,
    // Returned by `ScalarUDFImpl::aliases`; built in `ArrayRemoveAll::new`.
268    aliases: Vec<String>,
269}
270
271impl Default for ArrayRemoveAll {
272    fn default() -> Self {
273        Self::new()
274    }
275}
276
277impl ArrayRemoveAll {
278    pub fn new() -> Self {
279        Self {
280            signature: Signature::array_and_element(Volatility::Immutable),
281            aliases: vec!["list_remove_all".to_string()],
282        }
283    }
284}
285
286impl ScalarUDFImpl for ArrayRemoveAll {
287    fn name(&self) -> &str {
288        "array_remove_all"
289    }
290
291    fn signature(&self) -> &Signature {
292        &self.signature
293    }
294
295    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
296        internal_err!("return_field_from_args should be used instead")
297    }
298
299    fn return_field_from_args(
300        &self,
301        args: datafusion_expr::ReturnFieldArgs,
302    ) -> Result<FieldRef> {
303        Ok(Arc::clone(&args.arg_fields[0]))
304    }
305
306    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
307        make_scalar_function(array_remove_all_inner)(&args.args)
308    }
309
310    fn aliases(&self) -> &[String] {
311        &self.aliases
312    }
313
314    fn documentation(&self) -> Option<&Documentation> {
315        self.doc()
316    }
317}
318
319fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
320    let [array, element] = take_function_args("array_remove", args)?;
321
322    let arr_n = vec![1; array.len()];
323    array_remove_internal(array, element, &arr_n)
324}
325
326fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
327    let [array, element, max] = take_function_args("array_remove_n", args)?;
328
329    let arr_n = as_int64_array(max)?.values().to_vec();
330    array_remove_internal(array, element, &arr_n)
331}
332
333fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
334    let [array, element] = take_function_args("array_remove_all", args)?;
335
336    let arr_n = vec![i64::MAX; array.len()];
337    array_remove_internal(array, element, &arr_n)
338}
339
340fn array_remove_internal(
341    array: &ArrayRef,
342    element_array: &ArrayRef,
343    arr_n: &[i64],
344) -> Result<ArrayRef> {
345    match array.data_type() {
346        DataType::List(_) => {
347            let list_array = array.as_list::<i32>();
348            general_remove::<i32>(list_array, element_array, arr_n)
349        }
350        DataType::LargeList(_) => {
351            let list_array = array.as_list::<i64>();
352            general_remove::<i64>(list_array, element_array, arr_n)
353        }
354        array_type => {
355            exec_err!("array_remove_all does not support type '{array_type}'.")
356        }
357    }
358}
359
360/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
361/// of `element_array[i]`.
362///
363/// The type of each **element** in `list_array` must be the same as the type of
364/// `element_array`. This function also handles nested arrays
365/// ([`arrow::array::ListArray`] of [`arrow::array::ListArray`]s)
366///
367/// For example, when called to remove a list array (where each element is a
368/// list of int32s, the second argument are int32 arrays, and the
369/// third argument is the number of occurrences to remove
370///
371/// ```text
372/// general_remove(
373///   [1, 2, 3, 2], 2, 1    ==> [1, 3, 2]   (only the first 2 is removed)
374///   [4, 5, 6, 5], 5, 2    ==> [4, 6]  (both 5s are removed)
375/// )
376/// ```
377fn general_remove<OffsetSize: OffsetSizeTrait>(
378    list_array: &GenericListArray<OffsetSize>,
379    element_array: &ArrayRef,
380    arr_n: &[i64],
381) -> Result<ArrayRef> {
382    let list_field = match list_array.data_type() {
383        DataType::List(field) | DataType::LargeList(field) => field,
384        _ => {
385            return exec_err!(
386                "Expected List or LargeList data type, got {:?}",
387                list_array.data_type()
388            );
389        }
390    };
391    let original_data = list_array.values().to_data();
392    // Build up the offsets for the final output array
393    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
394    offsets.push(OffsetSize::zero());
395
396    let mut mutable = MutableArrayData::with_capacities(
397        vec![&original_data],
398        false,
399        Capacities::Array(original_data.len()),
400    );
401
402    // Pre-compute combined null bitmap
403    let nulls = NullBuffer::union(list_array.nulls(), element_array.nulls());
404
405    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
406        if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
407            offsets.push(offsets[row_index]);
408            continue;
409        }
410
411        let start = offset_window[0].to_usize().unwrap();
412        let end = offset_window[1].to_usize().unwrap();
413        // n is the number of elements to remove in this row
414        let n = arr_n[row_index];
415
416        // compare each element in the list, `false` means the element matches and should be removed
417        let eq_array = utils::compare_element_to_list(
418            &list_array.value(row_index),
419            element_array,
420            row_index,
421            false,
422        )?;
423
424        let num_to_remove = eq_array.false_count();
425
426        // Fast path: no elements to remove, copy entire row
427        if num_to_remove == 0 {
428            mutable.extend(0, start, end);
429            offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
430            continue;
431        }
432
433        // Remove at most `n` matching elements
434        let max_removals = n.min(num_to_remove as i64);
435        let mut removed = 0i64;
436        let mut copied = 0usize;
437        // marks the beginning of a range of elements pending to be copied.
438        let mut pending_batch_to_retain: Option<usize> = None;
439        for (i, keep) in eq_array.iter().enumerate() {
440            if keep == Some(false) && removed < max_removals {
441                // Flush pending batch before skipping this element
442                if let Some(bs) = pending_batch_to_retain {
443                    mutable.extend(0, start + bs, start + i);
444                    copied += i - bs;
445                    pending_batch_to_retain = None;
446                }
447                removed += 1;
448            } else if pending_batch_to_retain.is_none() {
449                pending_batch_to_retain = Some(i);
450            }
451        }
452
453        // Flush remaining batch
454        if let Some(bs) = pending_batch_to_retain {
455            mutable.extend(0, start + bs, start + eq_array.len());
456            copied += eq_array.len() - bs;
457        }
458
459        offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
460    }
461
462    let new_values = make_array(mutable.freeze());
463    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
464        Arc::clone(list_field),
465        OffsetBuffer::new(offsets.into()),
466        new_values,
467        nulls,
468    )?))
469}
470
471#[cfg(test)]
472mod tests {
473    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
474    use arrow::array::{
475        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
476    };
477    use arrow::datatypes::{DataType, Field, Int32Type};
478    use datafusion_common::ScalarValue;
479    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
480    use datafusion_expr_common::columnar_value::ColumnarValue;
481    use std::ops::Deref;
482    use std::sync::Arc;
483
484    #[test]
485    fn test_array_remove_nullability() {
486        for nullability in [true, false] {
487            for item_nullability in [true, false] {
488                let input_field = Arc::new(Field::new(
489                    "num",
490                    DataType::new_list(DataType::Int32, item_nullability),
491                    nullability,
492                ));
493                let args_fields = vec![
494                    Arc::clone(&input_field),
495                    Arc::new(Field::new("a", DataType::Int32, false)),
496                ];
497                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];
498
499                let result = ArrayRemove::new()
500                    .return_field_from_args(ReturnFieldArgs {
501                        arg_fields: &args_fields,
502                        scalar_arguments: &scalar_args,
503                    })
504                    .unwrap();
505
506                assert_eq!(result, input_field);
507            }
508        }
509    }
510
511    #[test]
512    fn test_array_remove_n_nullability() {
513        for nullability in [true, false] {
514            for item_nullability in [true, false] {
515                let input_field = Arc::new(Field::new(
516                    "num",
517                    DataType::new_list(DataType::Int32, item_nullability),
518                    nullability,
519                ));
520                let args_fields = vec![
521                    Arc::clone(&input_field),
522                    Arc::new(Field::new("a", DataType::Int32, false)),
523                    Arc::new(Field::new("b", DataType::Int64, false)),
524                ];
525                let scalar_args = vec![
526                    None,
527                    Some(&ScalarValue::Int32(Some(1))),
528                    Some(&ScalarValue::Int64(Some(1))),
529                ];
530
531                let result = ArrayRemoveN::new()
532                    .return_field_from_args(ReturnFieldArgs {
533                        arg_fields: &args_fields,
534                        scalar_arguments: &scalar_args,
535                    })
536                    .unwrap();
537
538                assert_eq!(result, input_field);
539            }
540        }
541    }
542
543    #[test]
544    fn test_array_remove_all_nullability() {
545        for nullability in [true, false] {
546            for item_nullability in [true, false] {
547                let input_field = Arc::new(Field::new(
548                    "num",
549                    DataType::new_list(DataType::Int32, item_nullability),
550                    nullability,
551                ));
552                let result = ArrayRemoveAll::new()
553                    .return_field_from_args(ReturnFieldArgs {
554                        arg_fields: &[Arc::clone(&input_field)],
555                        scalar_arguments: &[None],
556                    })
557                    .unwrap();
558
559                assert_eq!(result, input_field);
560            }
561        }
562    }
563
564    fn ensure_field_nullability<O: OffsetSizeTrait>(
565        field_nullable: bool,
566        list: GenericListArray<O>,
567    ) -> GenericListArray<O> {
568        let (field, offsets, values, nulls) = list.into_parts();
569
570        if field.is_nullable() == field_nullable {
571            return GenericListArray::new(field, offsets, values, nulls);
572        }
573        if !field_nullable {
574            assert_eq!(nulls, None);
575        }
576
577        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));
578
579        GenericListArray::new(field, offsets, values, nulls)
580    }
581
582    #[test]
583    fn test_array_remove_non_nullable() {
584        let input_list = Arc::new(ensure_field_nullability(
585            false,
586            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
587                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
588                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
589            ]),
590        ));
591        let expected_list = ensure_field_nullability(
592            false,
593            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
594                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
595                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
596            ]),
597        );
598
599        let element_to_remove = ScalarValue::Int32(Some(2));
600
601        assert_array_remove(input_list, expected_list, element_to_remove);
602    }
603
604    #[test]
605    fn test_array_remove_nullable() {
606        let input_list = Arc::new(ensure_field_nullability(
607            true,
608            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
609                Some(vec![
610                    Some(1),
611                    Some(2),
612                    Some(2),
613                    Some(3),
614                    None,
615                    Some(1),
616                    Some(4),
617                ]),
618                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
619            ]),
620        ));
621        let expected_list = ensure_field_nullability(
622            true,
623            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
624                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
625                Some(vec![Some(42), None, Some(63), Some(2)]),
626            ]),
627        );
628
629        let element_to_remove = ScalarValue::Int32(Some(2));
630
631        assert_array_remove(input_list, expected_list, element_to_remove);
632    }
633
634    fn assert_array_remove(
635        input_list: ArrayRef,
636        expected_list: GenericListArray<i32>,
637        element_to_remove: ScalarValue,
638    ) {
639        assert_eq!(input_list.data_type(), expected_list.data_type());
640        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
641        let input_list_len = input_list.len();
642        let input_list_data_type = input_list.data_type().clone();
643
644        let udf = ArrayRemove::new();
645        let args_fields = vec![
646            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
647            Arc::new(Field::new(
648                "el",
649                element_to_remove.data_type(),
650                element_to_remove.is_null(),
651            )),
652        ];
653        let scalar_args = vec![None, Some(&element_to_remove)];
654
655        let return_field = udf
656            .return_field_from_args(ReturnFieldArgs {
657                arg_fields: &args_fields,
658                scalar_arguments: &scalar_args,
659            })
660            .unwrap();
661
662        let result = udf
663            .invoke_with_args(ScalarFunctionArgs {
664                args: vec![
665                    ColumnarValue::Array(input_list),
666                    ColumnarValue::Scalar(element_to_remove),
667                ],
668                arg_fields: args_fields,
669                number_rows: input_list_len,
670                return_field,
671                config_options: Arc::new(Default::default()),
672            })
673            .unwrap();
674
675        assert_eq!(result.data_type(), input_list_data_type);
676        match result {
677            ColumnarValue::Array(array) => {
678                let result_list = array.as_list::<i32>();
679                assert_eq!(result_list, &expected_list);
680            }
681            _ => panic!("Expected ColumnarValue::Array"),
682        }
683    }
684
685    #[test]
686    fn test_array_remove_n_non_nullable() {
687        let input_list = Arc::new(ensure_field_nullability(
688            false,
689            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
690                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
691                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
692            ]),
693        ));
694        let expected_list = ensure_field_nullability(
695            false,
696            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
697                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
698                Some(([42, 55, 63]).iter().copied().map(Some)),
699            ]),
700        );
701
702        let element_to_remove = ScalarValue::Int32(Some(2));
703
704        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
705    }
706
707    #[test]
708    fn test_array_remove_n_nullable() {
709        let input_list = Arc::new(ensure_field_nullability(
710            true,
711            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
712                Some(vec![
713                    Some(1),
714                    Some(2),
715                    Some(2),
716                    Some(3),
717                    None,
718                    Some(1),
719                    Some(4),
720                ]),
721                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
722            ]),
723        ));
724        let expected_list = ensure_field_nullability(
725            true,
726            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
727                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
728                Some(vec![Some(42), None, Some(63)]),
729            ]),
730        );
731
732        let element_to_remove = ScalarValue::Int32(Some(2));
733
734        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
735    }
736
737    fn assert_array_remove_n(
738        input_list: ArrayRef,
739        expected_list: GenericListArray<i32>,
740        element_to_remove: ScalarValue,
741        n: i64,
742    ) {
743        assert_eq!(input_list.data_type(), expected_list.data_type());
744        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
745        let input_list_len = input_list.len();
746        let input_list_data_type = input_list.data_type().clone();
747
748        let count_scalar = ScalarValue::Int64(Some(n));
749
750        let udf = ArrayRemoveN::new();
751        let args_fields = vec![
752            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
753            Arc::new(Field::new(
754                "el",
755                element_to_remove.data_type(),
756                element_to_remove.is_null(),
757            )),
758            Arc::new(Field::new("count", DataType::Int64, false)),
759        ];
760        let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];
761
762        let return_field = udf
763            .return_field_from_args(ReturnFieldArgs {
764                arg_fields: &args_fields,
765                scalar_arguments: &scalar_args,
766            })
767            .unwrap();
768
769        let result = udf
770            .invoke_with_args(ScalarFunctionArgs {
771                args: vec![
772                    ColumnarValue::Array(input_list),
773                    ColumnarValue::Scalar(element_to_remove),
774                    ColumnarValue::Scalar(count_scalar),
775                ],
776                arg_fields: args_fields,
777                number_rows: input_list_len,
778                return_field,
779                config_options: Arc::new(Default::default()),
780            })
781            .unwrap();
782
783        assert_eq!(result.data_type(), input_list_data_type);
784        match result {
785            ColumnarValue::Array(array) => {
786                let result_list = array.as_list::<i32>();
787                assert_eq!(result_list, &expected_list);
788            }
789            _ => panic!("Expected ColumnarValue::Array"),
790        }
791    }
792
793    #[test]
794    fn test_array_remove_all_non_nullable() {
795        let input_list = Arc::new(ensure_field_nullability(
796            false,
797            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
798                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
799                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
800            ]),
801        ));
802        let expected_list = ensure_field_nullability(
803            false,
804            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
805                Some(([1, 3, 1, 4]).iter().copied().map(Some)),
806                Some(([42, 55, 63]).iter().copied().map(Some)),
807            ]),
808        );
809
810        let element_to_remove = ScalarValue::Int32(Some(2));
811
812        assert_array_remove_all(input_list, expected_list, element_to_remove);
813    }
814
815    #[test]
816    fn test_array_remove_all_nullable() {
817        let input_list = Arc::new(ensure_field_nullability(
818            true,
819            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
820                Some(vec![
821                    Some(1),
822                    Some(2),
823                    Some(2),
824                    Some(3),
825                    None,
826                    Some(1),
827                    Some(4),
828                ]),
829                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
830            ]),
831        ));
832        let expected_list = ensure_field_nullability(
833            true,
834            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
835                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
836                Some(vec![Some(42), None, Some(63)]),
837            ]),
838        );
839
840        let element_to_remove = ScalarValue::Int32(Some(2));
841
842        assert_array_remove_all(input_list, expected_list, element_to_remove);
843    }
844
845    fn assert_array_remove_all(
846        input_list: ArrayRef,
847        expected_list: GenericListArray<i32>,
848        element_to_remove: ScalarValue,
849    ) {
850        assert_eq!(input_list.data_type(), expected_list.data_type());
851        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
852        let input_list_len = input_list.len();
853        let input_list_data_type = input_list.data_type().clone();
854
855        let udf = ArrayRemoveAll::new();
856        let args_fields = vec![
857            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
858            Arc::new(Field::new(
859                "el",
860                element_to_remove.data_type(),
861                element_to_remove.is_null(),
862            )),
863        ];
864        let scalar_args = vec![None, Some(&element_to_remove)];
865
866        let return_field = udf
867            .return_field_from_args(ReturnFieldArgs {
868                arg_fields: &args_fields,
869                scalar_arguments: &scalar_args,
870            })
871            .unwrap();
872
873        let result = udf
874            .invoke_with_args(ScalarFunctionArgs {
875                args: vec![
876                    ColumnarValue::Array(input_list),
877                    ColumnarValue::Scalar(element_to_remove),
878                ],
879                arg_fields: args_fields,
880                number_rows: input_list_len,
881                return_field,
882                config_options: Arc::new(Default::default()),
883            })
884            .unwrap();
885
886        assert_eq!(result.data_type(), input_list_data_type);
887        match result {
888            ColumnarValue::Array(array) => {
889                let result_list = array.as_list::<i32>();
890                assert_eq!(result_list, &expected_list);
891            }
892            _ => panic!("Expected ColumnarValue::Array"),
893        }
894    }
895}