1use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23 Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, cast::AsArray,
24 new_empty_array,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, FieldRef};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
31use datafusion_expr::{
32 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33 ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Registers `ArrayRemove` with the crate's UDF helper macro. The arguments name
// the UDF type, the expression-function name (`array_remove`), its parameter
// names (`array element`), a description string, and the UDF accessor name
// (`array_remove_udf`).
// NOTE(review): the exact expansion lives in the `make_udf_expr_and_func!`
// definition (not in this file) — presumably an expr builder plus a shared UDF
// accessor; confirm there.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
46
/// Scalar UDF implementing `array_remove(array, element)`.
///
/// For every row, removes the first element of the list that is equal to
/// `element`. The user-facing documentation is generated from the
/// `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4] |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
    // Accepted argument types; built in `ArrayRemove::new`.
    signature: Signature,
    // Alternative names returned by `aliases()` (set to `list_remove` in `new`).
    aliases: Vec<String>,
}
73
74impl Default for ArrayRemove {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl ArrayRemove {
81 pub fn new() -> Self {
82 Self {
83 signature: Signature::array_and_element(Volatility::Immutable),
84 aliases: vec!["list_remove".to_string()],
85 }
86 }
87}
88
89impl ScalarUDFImpl for ArrayRemove {
90 fn as_any(&self) -> &dyn Any {
91 self
92 }
93
94 fn name(&self) -> &str {
95 "array_remove"
96 }
97
98 fn signature(&self) -> &Signature {
99 &self.signature
100 }
101
102 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
103 internal_err!("return_field_from_args should be used instead")
104 }
105
106 fn return_field_from_args(
107 &self,
108 args: datafusion_expr::ReturnFieldArgs,
109 ) -> Result<FieldRef> {
110 Ok(Arc::clone(&args.arg_fields[0]))
111 }
112
113 fn invoke_with_args(
114 &self,
115 args: datafusion_expr::ScalarFunctionArgs,
116 ) -> Result<ColumnarValue> {
117 make_scalar_function(array_remove_inner)(&args.args)
118 }
119
120 fn aliases(&self) -> &[String] {
121 &self.aliases
122 }
123
124 fn documentation(&self) -> Option<&Documentation> {
125 self.doc()
126 }
127}
128
// Registers `ArrayRemoveN` with the crate's UDF helper macro. The arguments name
// the UDF type, the expression-function name (`array_remove_n`), its parameter
// names (`array element max`), a description string, and the UDF accessor name
// (`array_remove_n_udf`).
// NOTE(review): the exact expansion lives in the `make_udf_expr_and_func!`
// definition (not in this file); confirm there.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
136
137#[user_doc(
138 doc_section(label = "Array Functions"),
139 description = "Removes the first `max` elements from the array equal to the given value.",
140 syntax_example = "array_remove_n(array, element, max))",
141 sql_example = r#"```sql
142> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
143+---------------------------------------------------------+
144| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
145+---------------------------------------------------------+
146| [1, 3, 2, 1, 4] |
147+---------------------------------------------------------+
148```"#,
149 argument(
150 name = "array",
151 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
152 ),
153 argument(
154 name = "element",
155 description = "Element to be removed from the array."
156 ),
157 argument(name = "max", description = "Number of first occurrences to remove.")
158)]
159#[derive(Debug, PartialEq, Eq, Hash)]
160pub(super) struct ArrayRemoveN {
161 signature: Signature,
162 aliases: Vec<String>,
163}
164
165impl ArrayRemoveN {
166 pub fn new() -> Self {
167 Self {
168 signature: Signature::new(
169 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
170 arguments: vec![
171 ArrayFunctionArgument::Array,
172 ArrayFunctionArgument::Element,
173 ArrayFunctionArgument::Index,
174 ],
175 array_coercion: Some(ListCoercion::FixedSizedListToList),
176 }),
177 Volatility::Immutable,
178 ),
179 aliases: vec!["list_remove_n".to_string()],
180 }
181 }
182}
183
184impl ScalarUDFImpl for ArrayRemoveN {
185 fn as_any(&self) -> &dyn Any {
186 self
187 }
188
189 fn name(&self) -> &str {
190 "array_remove_n"
191 }
192
193 fn signature(&self) -> &Signature {
194 &self.signature
195 }
196
197 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
198 internal_err!("return_field_from_args should be used instead")
199 }
200
201 fn return_field_from_args(
202 &self,
203 args: datafusion_expr::ReturnFieldArgs,
204 ) -> Result<FieldRef> {
205 Ok(Arc::clone(&args.arg_fields[0]))
206 }
207
208 fn invoke_with_args(
209 &self,
210 args: datafusion_expr::ScalarFunctionArgs,
211 ) -> Result<ColumnarValue> {
212 make_scalar_function(array_remove_n_inner)(&args.args)
213 }
214
215 fn aliases(&self) -> &[String] {
216 &self.aliases
217 }
218
219 fn documentation(&self) -> Option<&Documentation> {
220 self.doc()
221 }
222}
223
// Registers `ArrayRemoveAll` with the crate's UDF helper macro. The arguments
// name the UDF type, the expression-function name (`array_remove_all`), its
// parameter names (`array element`), a description string, and the UDF
// accessor name (`array_remove_all_udf`).
// NOTE(review): the exact expansion lives in the `make_udf_expr_and_func!`
// definition (not in this file); confirm there.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
231
/// Scalar UDF implementing `array_remove_all(array, element)`.
///
/// For every row, removes every element of the list that is equal to
/// `element`. The user-facing documentation is generated from the
/// `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4] |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveAll {
    // Accepted argument types; built in `ArrayRemoveAll::new`.
    signature: Signature,
    // Alternative names returned by `aliases()` (set to `list_remove_all` in `new`).
    aliases: Vec<String>,
}
258
259impl ArrayRemoveAll {
260 pub fn new() -> Self {
261 Self {
262 signature: Signature::array_and_element(Volatility::Immutable),
263 aliases: vec!["list_remove_all".to_string()],
264 }
265 }
266}
267
268impl ScalarUDFImpl for ArrayRemoveAll {
269 fn as_any(&self) -> &dyn Any {
270 self
271 }
272
273 fn name(&self) -> &str {
274 "array_remove_all"
275 }
276
277 fn signature(&self) -> &Signature {
278 &self.signature
279 }
280
281 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
282 internal_err!("return_field_from_args should be used instead")
283 }
284
285 fn return_field_from_args(
286 &self,
287 args: datafusion_expr::ReturnFieldArgs,
288 ) -> Result<FieldRef> {
289 Ok(Arc::clone(&args.arg_fields[0]))
290 }
291
292 fn invoke_with_args(
293 &self,
294 args: datafusion_expr::ScalarFunctionArgs,
295 ) -> Result<ColumnarValue> {
296 make_scalar_function(array_remove_all_inner)(&args.args)
297 }
298
299 fn aliases(&self) -> &[String] {
300 &self.aliases
301 }
302
303 fn documentation(&self) -> Option<&Documentation> {
304 self.doc()
305 }
306}
307
308fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
309 let [array, element] = take_function_args("array_remove", args)?;
310
311 let arr_n = vec![1; array.len()];
312 array_remove_internal(array, element, &arr_n)
313}
314
315fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
316 let [array, element, max] = take_function_args("array_remove_n", args)?;
317
318 let arr_n = as_int64_array(max)?.values().to_vec();
319 array_remove_internal(array, element, &arr_n)
320}
321
322fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
323 let [array, element] = take_function_args("array_remove_all", args)?;
324
325 let arr_n = vec![i64::MAX; array.len()];
326 array_remove_internal(array, element, &arr_n)
327}
328
329fn array_remove_internal(
330 array: &ArrayRef,
331 element_array: &ArrayRef,
332 arr_n: &[i64],
333) -> Result<ArrayRef> {
334 match array.data_type() {
335 DataType::List(_) => {
336 let list_array = array.as_list::<i32>();
337 general_remove::<i32>(list_array, element_array, arr_n)
338 }
339 DataType::LargeList(_) => {
340 let list_array = array.as_list::<i64>();
341 general_remove::<i64>(list_array, element_array, arr_n)
342 }
343 array_type => {
344 exec_err!("array_remove_all does not support type '{array_type}'.")
345 }
346 }
347}
348
349fn general_remove<OffsetSize: OffsetSizeTrait>(
367 list_array: &GenericListArray<OffsetSize>,
368 element_array: &ArrayRef,
369 arr_n: &[i64],
370) -> Result<ArrayRef> {
371 let list_field = match list_array.data_type() {
372 DataType::List(field) | DataType::LargeList(field) => field,
373 _ => {
374 return exec_err!(
375 "Expected List or LargeList data type, got {:?}",
376 list_array.data_type()
377 );
378 }
379 };
380 let data_type = list_field.data_type();
381 let mut new_values = vec![];
382 let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
384 offsets.push(OffsetSize::zero());
385
386 for (row_index, (list_array_row, n)) in
388 list_array.iter().zip(arr_n.iter()).enumerate()
389 {
390 match list_array_row {
391 Some(list_array_row) => {
392 let eq_array = utils::compare_element_to_list(
393 &list_array_row,
394 element_array,
395 row_index,
396 false,
397 )?;
398
399 let eq_array = if eq_array.false_count() < *n as usize {
401 eq_array
402 } else {
403 let mut count = 0;
404 eq_array
405 .iter()
406 .map(|e| {
407 if let Some(false) = e {
409 if count < *n {
410 count += 1;
411 e
412 } else {
413 Some(true)
414 }
415 } else {
416 e
417 }
418 })
419 .collect::<BooleanArray>()
420 };
421
422 let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
423 offsets.push(
424 offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
425 );
426 new_values.push(filtered_array);
427 }
428 None => {
429 offsets.push(offsets[row_index]);
431 }
432 }
433 }
434
435 let values = if new_values.is_empty() {
436 new_empty_array(data_type)
437 } else {
438 let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
439 arrow::compute::concat(&new_values)?
440 };
441
442 Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
443 Arc::clone(list_field),
444 OffsetBuffer::new(offsets.into()),
445 values,
446 list_array.nulls().cloned(),
447 )?))
448}
449
#[cfg(test)]
mod tests {
    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
    use arrow::array::{
        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
    };
    use arrow::datatypes::{DataType, Field, FieldRef, Int32Type};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::ops::Deref;
    use std::sync::Arc;

    /// Field for the `Int32` element argument shared by the nullability tests.
    fn element_field() -> FieldRef {
        Arc::new(Field::new("a", DataType::Int32, false))
    }

    /// Asserts that `udf` returns exactly the field of its first (list) argument
    /// for every combination of list and item nullability.
    ///
    /// `trailing_args` are the `(field, scalar)` pairs passed after the list.
    fn assert_return_field_is_input_field(
        udf: &impl ScalarUDFImpl,
        trailing_args: &[(FieldRef, ScalarValue)],
    ) {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let arg_fields: Vec<FieldRef> = std::iter::once(Arc::clone(&input_field))
                    .chain(trailing_args.iter().map(|(field, _)| Arc::clone(field)))
                    .collect();
                let scalar_arguments: Vec<Option<&ScalarValue>> = std::iter::once(None)
                    .chain(trailing_args.iter().map(|(_, value)| Some(value)))
                    .collect();

                let result = udf
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &arg_fields,
                        scalar_arguments: &scalar_arguments,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    #[test]
    fn test_array_remove_nullability() {
        assert_return_field_is_input_field(
            &ArrayRemove::new(),
            &[(element_field(), ScalarValue::Int32(Some(1)))],
        );
    }

    #[test]
    fn test_array_remove_n_nullability() {
        assert_return_field_is_input_field(
            &ArrayRemoveN::new(),
            &[
                (element_field(), ScalarValue::Int32(Some(1))),
                (
                    Arc::new(Field::new("b", DataType::Int64, false)),
                    ScalarValue::Int64(Some(1)),
                ),
            ],
        );
    }

    #[test]
    fn test_array_remove_all_nullability() {
        // Pass the element argument too, matching how the function is called.
        assert_return_field_is_input_field(
            &ArrayRemoveAll::new(),
            &[(element_field(), ScalarValue::Int32(Some(1)))],
        );
    }

    /// Returns `list` with its item field's nullability set to `field_nullable`.
    fn ensure_field_nullability<O: OffsetSizeTrait>(
        field_nullable: bool,
        list: GenericListArray<O>,
    ) -> GenericListArray<O> {
        let (field, offsets, values, nulls) = list.into_parts();

        if field.is_nullable() == field_nullable {
            return GenericListArray::new(field, offsets, values, nulls);
        }
        if !field_nullable {
            assert_eq!(nulls, None);
        }

        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));

        GenericListArray::new(field, offsets, values, nulls)
    }

    /// Builds a list of fully-present `Int32` rows with a non-nullable item field.
    fn non_nullable_list(rows: &[&[i32]]) -> ListArray {
        ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(
                rows.iter().map(|row| Some(row.iter().copied().map(Some))),
            ),
        )
    }

    /// Builds a list of `Int32` rows (items may be null) with a nullable item field.
    fn nullable_list(rows: &[&[Option<i32>]]) -> ListArray {
        ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(
                rows.iter().map(|row| Some(row.iter().copied())),
            ),
        )
    }

    /// `[[1, 2, 2, 3, 2, 1, 4], [42, 2, 55, 63, 2]]` with a non-nullable item field.
    fn non_nullable_input() -> ArrayRef {
        Arc::new(non_nullable_list(&[
            &[1, 2, 2, 3, 2, 1, 4],
            &[42, 2, 55, 63, 2],
        ]))
    }

    /// `[[1, 2, 2, 3, NULL, 1, 4], [42, 2, NULL, 63, 2]]` with a nullable item field.
    fn nullable_input() -> ArrayRef {
        Arc::new(nullable_list(&[
            &[Some(1), Some(2), Some(2), Some(3), None, Some(1), Some(4)],
            &[Some(42), Some(2), None, Some(63), Some(2)],
        ]))
    }

    /// Invokes `udf` with `input_list`, `element_to_remove` and, when `max` is
    /// given, an `Int64` count argument. Asserts that the result equals
    /// `expected_list` and keeps the input's data type.
    fn assert_udf_output(
        udf: &impl ScalarUDFImpl,
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
        max: Option<i64>,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let mut arg_fields: Vec<FieldRef> = vec![
            Arc::new(Field::new("num", input_list_data_type.clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let mut scalars = vec![element_to_remove];
        if let Some(max) = max {
            arg_fields.push(Arc::new(Field::new("count", DataType::Int64, false)));
            scalars.push(ScalarValue::Int64(Some(max)));
        }

        let scalar_arguments: Vec<Option<&ScalarValue>> =
            std::iter::once(None).chain(scalars.iter().map(Some)).collect();
        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &scalar_arguments,
            })
            .unwrap();

        let args = std::iter::once(ColumnarValue::Array(input_list))
            .chain(scalars.into_iter().map(ColumnarValue::Scalar))
            .collect();
        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    #[test]
    fn test_array_remove_non_nullable() {
        assert_udf_output(
            &ArrayRemove::new(),
            non_nullable_input(),
            non_nullable_list(&[&[1, 2, 3, 2, 1, 4], &[42, 55, 63, 2]]),
            ScalarValue::Int32(Some(2)),
            None,
        );
    }

    #[test]
    fn test_array_remove_nullable() {
        assert_udf_output(
            &ArrayRemove::new(),
            nullable_input(),
            nullable_list(&[
                &[Some(1), Some(2), Some(3), None, Some(1), Some(4)],
                &[Some(42), None, Some(63), Some(2)],
            ]),
            ScalarValue::Int32(Some(2)),
            None,
        );
    }

    #[test]
    fn test_array_remove_n_non_nullable() {
        assert_udf_output(
            &ArrayRemoveN::new(),
            non_nullable_input(),
            non_nullable_list(&[&[1, 3, 2, 1, 4], &[42, 55, 63]]),
            ScalarValue::Int32(Some(2)),
            Some(2),
        );
    }

    #[test]
    fn test_array_remove_n_nullable() {
        assert_udf_output(
            &ArrayRemoveN::new(),
            nullable_input(),
            nullable_list(&[
                &[Some(1), Some(3), None, Some(1), Some(4)],
                &[Some(42), None, Some(63)],
            ]),
            ScalarValue::Int32(Some(2)),
            Some(2),
        );
    }

    #[test]
    fn test_array_remove_all_non_nullable() {
        assert_udf_output(
            &ArrayRemoveAll::new(),
            non_nullable_input(),
            non_nullable_list(&[&[1, 3, 1, 4], &[42, 55, 63]]),
            ScalarValue::Int32(Some(2)),
            None,
        );
    }

    #[test]
    fn test_array_remove_all_nullable() {
        assert_udf_output(
            &ArrayRemoveAll::new(),
            nullable_input(),
            nullable_list(&[
                &[Some(1), Some(3), None, Some(1), Some(4)],
                &[Some(42), None, Some(63)],
            ]),
            ScalarValue::Int32(Some(2)),
            None,
        );
    }
}