1use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23 Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder,
24 OffsetSizeTrait, cast::AsArray, make_array,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, FieldRef};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
31use datafusion_expr::{
32 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33 ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Registers `ArrayRemove` with the UDF helper macro. Presumably this generates the
// `array_remove(array, element)` expression builder and the `array_remove_udf`
// accessor; NOTE(review): confirm against the `make_udf_expr_and_func!` definition.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
46
// `array_remove(array, element)` (alias `list_remove`): removes the first element of
// each list that equals `element`. The `#[user_doc]` attribute presumably generates
// the `doc()` accessor that `ScalarUDFImpl::documentation` returns below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4]                           |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
    // Accepted argument shape: an array followed by an element.
    signature: Signature,
    // Alternative SQL names resolved to this function (`list_remove`).
    aliases: Vec<String>,
}
73
74impl Default for ArrayRemove {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl ArrayRemove {
81 pub fn new() -> Self {
82 Self {
83 signature: Signature::array_and_element(Volatility::Immutable),
84 aliases: vec!["list_remove".to_string()],
85 }
86 }
87}
88
89impl ScalarUDFImpl for ArrayRemove {
90 fn as_any(&self) -> &dyn Any {
91 self
92 }
93
94 fn name(&self) -> &str {
95 "array_remove"
96 }
97
98 fn signature(&self) -> &Signature {
99 &self.signature
100 }
101
102 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103 internal_err!("return_field_from_args should be used instead")
104 }
105
106 fn return_field_from_args(
107 &self,
108 args: datafusion_expr::ReturnFieldArgs,
109 ) -> Result<FieldRef> {
110 Ok(Arc::clone(&args.arg_fields[0]))
111 }
112
113 fn invoke_with_args(
114 &self,
115 args: datafusion_expr::ScalarFunctionArgs,
116 ) -> Result<ColumnarValue> {
117 make_scalar_function(array_remove_inner)(&args.args)
118 }
119
120 fn aliases(&self) -> &[String] {
121 &self.aliases
122 }
123
124 fn documentation(&self) -> Option<&Documentation> {
125 self.doc()
126 }
127}
128
// Registers `ArrayRemoveN` with the UDF helper macro. Presumably this generates the
// `array_remove_n(array, element, max)` expression builder and the
// `array_remove_n_udf` accessor; NOTE(review): confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
136
137#[user_doc(
138 doc_section(label = "Array Functions"),
139 description = "Removes the first `max` elements from the array equal to the given value.",
140 syntax_example = "array_remove_n(array, element, max))",
141 sql_example = r#"```sql
142> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
143+---------------------------------------------------------+
144| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
145+---------------------------------------------------------+
146| [1, 3, 2, 1, 4] |
147+---------------------------------------------------------+
148```"#,
149 argument(
150 name = "array",
151 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
152 ),
153 argument(
154 name = "element",
155 description = "Element to be removed from the array."
156 ),
157 argument(name = "max", description = "Number of first occurrences to remove.")
158)]
159#[derive(Debug, PartialEq, Eq, Hash)]
160pub(super) struct ArrayRemoveN {
161 signature: Signature,
162 aliases: Vec<String>,
163}
164
165impl ArrayRemoveN {
166 pub fn new() -> Self {
167 Self {
168 signature: Signature::new(
169 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
170 arguments: vec![
171 ArrayFunctionArgument::Array,
172 ArrayFunctionArgument::Element,
173 ArrayFunctionArgument::Index,
174 ],
175 array_coercion: Some(ListCoercion::FixedSizedListToList),
176 }),
177 Volatility::Immutable,
178 ),
179 aliases: vec!["list_remove_n".to_string()],
180 }
181 }
182}
183
184impl ScalarUDFImpl for ArrayRemoveN {
185 fn as_any(&self) -> &dyn Any {
186 self
187 }
188
189 fn name(&self) -> &str {
190 "array_remove_n"
191 }
192
193 fn signature(&self) -> &Signature {
194 &self.signature
195 }
196
197 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
198 internal_err!("return_field_from_args should be used instead")
199 }
200
201 fn return_field_from_args(
202 &self,
203 args: datafusion_expr::ReturnFieldArgs,
204 ) -> Result<FieldRef> {
205 Ok(Arc::clone(&args.arg_fields[0]))
206 }
207
208 fn invoke_with_args(
209 &self,
210 args: datafusion_expr::ScalarFunctionArgs,
211 ) -> Result<ColumnarValue> {
212 make_scalar_function(array_remove_n_inner)(&args.args)
213 }
214
215 fn aliases(&self) -> &[String] {
216 &self.aliases
217 }
218
219 fn documentation(&self) -> Option<&Documentation> {
220 self.doc()
221 }
222}
223
// Registers `ArrayRemoveAll` with the UDF helper macro. Presumably this generates the
// `array_remove_all(array, element)` expression builder and the
// `array_remove_all_udf` accessor; NOTE(review): confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
231
// `array_remove_all(array, element)` (alias `list_remove_all`): removes every element
// of each list that equals `element`. The `#[user_doc]` attribute presumably generates
// the `doc()` accessor that `ScalarUDFImpl::documentation` returns below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4]                                     |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveAll {
    // Accepted argument shape: an array followed by an element.
    signature: Signature,
    // Alternative SQL names resolved to this function (`list_remove_all`).
    aliases: Vec<String>,
}
258
259impl ArrayRemoveAll {
260 pub fn new() -> Self {
261 Self {
262 signature: Signature::array_and_element(Volatility::Immutable),
263 aliases: vec!["list_remove_all".to_string()],
264 }
265 }
266}
267
268impl ScalarUDFImpl for ArrayRemoveAll {
269 fn as_any(&self) -> &dyn Any {
270 self
271 }
272
273 fn name(&self) -> &str {
274 "array_remove_all"
275 }
276
277 fn signature(&self) -> &Signature {
278 &self.signature
279 }
280
281 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
282 internal_err!("return_field_from_args should be used instead")
283 }
284
285 fn return_field_from_args(
286 &self,
287 args: datafusion_expr::ReturnFieldArgs,
288 ) -> Result<FieldRef> {
289 Ok(Arc::clone(&args.arg_fields[0]))
290 }
291
292 fn invoke_with_args(
293 &self,
294 args: datafusion_expr::ScalarFunctionArgs,
295 ) -> Result<ColumnarValue> {
296 make_scalar_function(array_remove_all_inner)(&args.args)
297 }
298
299 fn aliases(&self) -> &[String] {
300 &self.aliases
301 }
302
303 fn documentation(&self) -> Option<&Documentation> {
304 self.doc()
305 }
306}
307
308fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
309 let [array, element] = take_function_args("array_remove", args)?;
310
311 let arr_n = vec![1; array.len()];
312 array_remove_internal(array, element, &arr_n)
313}
314
315fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
316 let [array, element, max] = take_function_args("array_remove_n", args)?;
317
318 let arr_n = as_int64_array(max)?.values().to_vec();
319 array_remove_internal(array, element, &arr_n)
320}
321
322fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
323 let [array, element] = take_function_args("array_remove_all", args)?;
324
325 let arr_n = vec![i64::MAX; array.len()];
326 array_remove_internal(array, element, &arr_n)
327}
328
329fn array_remove_internal(
330 array: &ArrayRef,
331 element_array: &ArrayRef,
332 arr_n: &[i64],
333) -> Result<ArrayRef> {
334 match array.data_type() {
335 DataType::List(_) => {
336 let list_array = array.as_list::<i32>();
337 general_remove::<i32>(list_array, element_array, arr_n)
338 }
339 DataType::LargeList(_) => {
340 let list_array = array.as_list::<i64>();
341 general_remove::<i64>(list_array, element_array, arr_n)
342 }
343 array_type => {
344 exec_err!("array_remove_all does not support type '{array_type}'.")
345 }
346 }
347}
348
/// Removes up to `arr_n[i]` occurrences of `element_array[i]` from each row
/// `list_array[i]`, keeping the remaining elements in their original order.
///
/// * An output row is null (and empty) when either the list row or the element
///   at that row is null.
/// * A non-positive `arr_n[i]` removes nothing; `i64::MAX` removes every match.
/// * Matching is delegated to `utils::compare_element_to_list`, called with
///   `false` as its last argument; entries that come back `false` are treated
///   as matches to remove. NOTE(review): null-vs-null matching semantics are
///   defined by that helper — confirm there.
///
/// Retained elements are copied in contiguous runs via `MutableArrayData`
/// instead of element by element.
///
/// # Errors
/// Returns an execution error if `list_array` is not `List`/`LargeList`, and
/// propagates errors from element comparison or from building the result.
///
/// # Panics
/// Panics if `arr_n` is shorter than `list_array`, since it is indexed per row.
fn general_remove<OffsetSize: OffsetSizeTrait>(
    list_array: &GenericListArray<OffsetSize>,
    element_array: &ArrayRef,
    arr_n: &[i64],
) -> Result<ArrayRef> {
    // Reuse the input's item field so the output list type matches the input.
    let list_field = match list_array.data_type() {
        DataType::List(field) | DataType::LargeList(field) => field,
        _ => {
            return exec_err!(
                "Expected List or LargeList data type, got {:?}",
                list_array.data_type()
            );
        }
    };
    // The row offsets below are used directly as positions into this child data.
    let original_data = list_array.values().to_data();
    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
    offsets.push(OffsetSize::zero());

    let mut mutable = MutableArrayData::with_capacities(
        vec![&original_data],
        false,
        Capacities::Array(original_data.len()),
    );
    let mut valid = NullBufferBuilder::new(list_array.len());

    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
        // Null list or null element: emit a null, zero-length row.
        if list_array.is_null(row_index) || element_array.is_null(row_index) {
            offsets.push(offsets[row_index]);
            valid.append_null();
            continue;
        }

        let start = offset_window[0].to_usize().unwrap();
        let end = offset_window[1].to_usize().unwrap();
        let n = arr_n[row_index];

        // One boolean per element of this row; `false` marks a match (see the
        // `keep == Some(false)` test below).
        let eq_array = utils::compare_element_to_list(
            &list_array.value(row_index),
            element_array,
            row_index,
            false,
        )?;

        let num_to_remove = eq_array.false_count();

        // Fast path: nothing matches, copy the whole row in one go.
        if num_to_remove == 0 {
            mutable.extend(0, start, end);
            offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
            valid.append_non_null();
            continue;
        }

        // Never remove more than were found; a non-positive `n` removes nothing.
        let max_removals = n.min(num_to_remove as i64);
        let mut removed = 0i64;
        let mut copied = 0usize;
        // Row-relative start of the current run of retained elements, if any.
        let mut pending_batch_to_retain: Option<usize> = None;
        for (i, keep) in eq_array.iter().enumerate() {
            if keep == Some(false) && removed < max_removals {
                // Flush the run of retained elements preceding this match.
                if let Some(bs) = pending_batch_to_retain {
                    mutable.extend(0, start + bs, start + i);
                    copied += i - bs;
                    pending_batch_to_retain = None;
                }
                removed += 1;
            } else if pending_batch_to_retain.is_none() {
                pending_batch_to_retain = Some(i);
            }
        }

        // Flush the trailing run of retained elements, if any.
        if let Some(bs) = pending_batch_to_retain {
            mutable.extend(0, start + bs, start + eq_array.len());
            copied += eq_array.len() - bs;
        }

        offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
        valid.append_non_null();
    }

    let new_values = make_array(mutable.freeze());
    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
        Arc::clone(list_field),
        OffsetBuffer::new(offsets.into()),
        new_values,
        valid.finish(),
    )?))
}
460
#[cfg(test)]
mod tests {
    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
    use arrow::array::{
        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
    };
    use arrow::datatypes::{DataType, Field, Int32Type};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::ops::Deref;
    use std::sync::Arc;

    #[test]
    fn test_array_remove_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                ];
                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];

                let result = ArrayRemove::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    #[test]
    fn test_array_remove_n_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                    Arc::new(Field::new("b", DataType::Int64, false)),
                ];
                let scalar_args = vec![
                    None,
                    Some(&ScalarValue::Int32(Some(1))),
                    Some(&ScalarValue::Int64(Some(1))),
                ];

                let result = ArrayRemoveN::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    #[test]
    fn test_array_remove_all_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                // Supply both (array, element) arguments, matching the function's
                // arity and the sibling nullability tests.
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                ];
                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];

                let result = ArrayRemoveAll::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    /// Rebuilds `list` with its item field's nullability set to `field_nullable`.
    fn ensure_field_nullability<O: OffsetSizeTrait>(
        field_nullable: bool,
        list: GenericListArray<O>,
    ) -> GenericListArray<O> {
        let (field, offsets, values, nulls) = list.into_parts();

        if field.is_nullable() == field_nullable {
            return GenericListArray::new(field, offsets, values, nulls);
        }
        if !field_nullable {
            assert_eq!(nulls, None);
        }

        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));

        GenericListArray::new(field, offsets, values, nulls)
    }

    #[test]
    fn test_array_remove_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove(input_list, expected_list, element_to_remove);
    }

    #[test]
    fn test_array_remove_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63), Some(2)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove(input_list, expected_list, element_to_remove);
    }

    #[test]
    fn test_array_remove_null_list_row() {
        // A null list row stays null; other rows lose only their first match.
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3)]),
                None,
                Some(vec![Some(2), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3)]),
                None,
                Some(vec![Some(2)]),
            ]),
        );

        assert_array_remove(input_list, expected_list, ScalarValue::Int32(Some(2)));
    }

    #[test]
    fn test_array_remove_null_element() {
        // A null element to remove turns every output row into null.
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(2)]),
                Some(vec![Some(2), Some(3)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                None::<Vec<Option<i32>>>,
                None,
            ]),
        );

        assert_array_remove(input_list, expected_list, ScalarValue::Int32(None));
    }

    fn assert_array_remove(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let udf = ArrayRemove::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let scalar_args = vec![None, Some(&element_to_remove)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    #[test]
    fn test_array_remove_n_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
    }

    #[test]
    fn test_array_remove_n_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
    }

    fn assert_array_remove_n(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
        n: i64,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let count_scalar = ScalarValue::Int64(Some(n));

        let udf = ArrayRemoveN::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
            Arc::new(Field::new("count", DataType::Int64, false)),
        ];
        let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                    ColumnarValue::Scalar(count_scalar),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    #[test]
    fn test_array_remove_all_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 3, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_all(input_list, expected_list, element_to_remove);
    }

    #[test]
    fn test_array_remove_all_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_all(input_list, expected_list, element_to_remove);
    }

    fn assert_array_remove_all(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let udf = ArrayRemoveAll::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let scalar_args = vec![None, Some(&element_to_remove)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }
}