1use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23 Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
24 cast::AsArray, make_array,
25};
26use arrow::buffer::{NullBuffer, OffsetBuffer};
27use arrow::datatypes::{DataType, FieldRef};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
31use datafusion_expr::{
32 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33 ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::sync::Arc;
37
// Registers `ArrayRemove` through the crate's `make_udf_expr_and_func!` macro
// (defined elsewhere). Judging by its arguments, this yields an
// `array_remove(array, element)` expression builder and an `array_remove_udf`
// accessor; confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
    array_remove_udf
);

// Scalar UDF `array_remove` (alias `list_remove`): removes at most one entry
// equal to `element` from each list row. User-facing documentation is produced
// from the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4] |
+----------------------------------------------+

> select array_remove([1, 2, NULL, 2, 4], 2);
+---------------------------------------------------+
| array_remove(List([1,2,NULL,2,4]),Int64(2)) |
+---------------------------------------------------+
| [1, NULL, 2, 4] |
+---------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
    // Argument signature, built in `new` via `Signature::array_and_element`.
    signature: Signature,
    // Alternative SQL names (`list_remove`).
    aliases: Vec<String>,
}
79
80impl Default for ArrayRemove {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl ArrayRemove {
87 pub fn new() -> Self {
88 Self {
89 signature: Signature::array_and_element(Volatility::Immutable),
90 aliases: vec!["list_remove".to_string()],
91 }
92 }
93}
94
95impl ScalarUDFImpl for ArrayRemove {
96 fn name(&self) -> &str {
97 "array_remove"
98 }
99
100 fn signature(&self) -> &Signature {
101 &self.signature
102 }
103
104 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
105 internal_err!("return_field_from_args should be used instead")
106 }
107
108 fn return_field_from_args(
109 &self,
110 args: datafusion_expr::ReturnFieldArgs,
111 ) -> Result<FieldRef> {
112 Ok(Arc::clone(&args.arg_fields[0]))
113 }
114
115 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
116 make_scalar_function(array_remove_inner)(&args.args)
117 }
118
119 fn aliases(&self) -> &[String] {
120 &self.aliases
121 }
122
123 fn documentation(&self) -> Option<&Documentation> {
124 self.doc()
125 }
126}
127
// Registers `ArrayRemoveN` through the crate's `make_udf_expr_and_func!` macro
// (defined elsewhere). Judging by its arguments, this yields an
// `array_remove_n(array, element, max)` expression builder and an
// `array_remove_n_udf` accessor; confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
    array_remove_n_udf
);

// Scalar UDF `array_remove_n` (alias `list_remove_n`): removes up to `max`
// leading entries equal to `element` from each list row. User-facing
// documentation is produced from the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
    syntax_example = "array_remove_n(array, element, max)",
    sql_example = r#"```sql
> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
+---------------------------------------------------------+
| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
+---------------------------------------------------------+
| [1, 3, 2, 1, 4] |
+---------------------------------------------------------+

> select array_remove_n([1, 2, NULL, 2, 4], 2, 2);
+----------------------------------------------------------+
| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) |
+----------------------------------------------------------+
| [1, NULL, 4] |
+----------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    ),
    argument(name = "max", description = "Number of first occurrences to remove.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemoveN {
    // Argument signature (array, element, index-typed `max`), built in `new`.
    signature: Signature,
    // Alternative SQL names (`list_remove_n`).
    aliases: Vec<String>,
}
170
171impl Default for ArrayRemoveN {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl ArrayRemoveN {
178 pub fn new() -> Self {
179 Self {
180 signature: Signature::new(
181 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
182 arguments: vec![
183 ArrayFunctionArgument::Array,
184 ArrayFunctionArgument::Element,
185 ArrayFunctionArgument::Index,
186 ],
187 array_coercion: Some(ListCoercion::FixedSizedListToList),
188 }),
189 Volatility::Immutable,
190 ),
191 aliases: vec!["list_remove_n".to_string()],
192 }
193 }
194}
195
196impl ScalarUDFImpl for ArrayRemoveN {
197 fn name(&self) -> &str {
198 "array_remove_n"
199 }
200
201 fn signature(&self) -> &Signature {
202 &self.signature
203 }
204
205 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
206 internal_err!("return_field_from_args should be used instead")
207 }
208
209 fn return_field_from_args(
210 &self,
211 args: datafusion_expr::ReturnFieldArgs,
212 ) -> Result<FieldRef> {
213 Ok(Arc::clone(&args.arg_fields[0]))
214 }
215
216 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
217 make_scalar_function(array_remove_n_inner)(&args.args)
218 }
219
220 fn aliases(&self) -> &[String] {
221 &self.aliases
222 }
223
224 fn documentation(&self) -> Option<&Documentation> {
225 self.doc()
226 }
227}
228
// Registers `ArrayRemoveAll` through the crate's `make_udf_expr_and_func!` macro
// (defined elsewhere). Judging by its arguments, this yields an
// `array_remove_all(array, element)` expression builder and an
// `array_remove_all_udf` accessor; confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
    array_remove_all_udf
);

// Scalar UDF `array_remove_all` (alias `list_remove_all`): removes every entry
// equal to `element` from each list row. User-facing documentation is produced
// from the `user_doc` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4] |
+--------------------------------------------------+

> select array_remove_all([1, 2, NULL, 2, 4], 2);
+-----------------------------------------------------+
| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) |
+-----------------------------------------------------+
| [1, NULL, 4] |
+-----------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemoveAll {
    // Argument signature, built in `new` via `Signature::array_and_element`.
    signature: Signature,
    // Alternative SQL names (`list_remove_all`).
    aliases: Vec<String>,
}
270
271impl Default for ArrayRemoveAll {
272 fn default() -> Self {
273 Self::new()
274 }
275}
276
277impl ArrayRemoveAll {
278 pub fn new() -> Self {
279 Self {
280 signature: Signature::array_and_element(Volatility::Immutable),
281 aliases: vec!["list_remove_all".to_string()],
282 }
283 }
284}
285
286impl ScalarUDFImpl for ArrayRemoveAll {
287 fn name(&self) -> &str {
288 "array_remove_all"
289 }
290
291 fn signature(&self) -> &Signature {
292 &self.signature
293 }
294
295 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
296 internal_err!("return_field_from_args should be used instead")
297 }
298
299 fn return_field_from_args(
300 &self,
301 args: datafusion_expr::ReturnFieldArgs,
302 ) -> Result<FieldRef> {
303 Ok(Arc::clone(&args.arg_fields[0]))
304 }
305
306 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
307 make_scalar_function(array_remove_all_inner)(&args.args)
308 }
309
310 fn aliases(&self) -> &[String] {
311 &self.aliases
312 }
313
314 fn documentation(&self) -> Option<&Documentation> {
315 self.doc()
316 }
317}
318
319fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
320 let [array, element] = take_function_args("array_remove", args)?;
321
322 let arr_n = vec![1; array.len()];
323 array_remove_internal(array, element, &arr_n)
324}
325
326fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
327 let [array, element, max] = take_function_args("array_remove_n", args)?;
328
329 let arr_n = as_int64_array(max)?.values().to_vec();
330 array_remove_internal(array, element, &arr_n)
331}
332
333fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
334 let [array, element] = take_function_args("array_remove_all", args)?;
335
336 let arr_n = vec![i64::MAX; array.len()];
337 array_remove_internal(array, element, &arr_n)
338}
339
340fn array_remove_internal(
341 array: &ArrayRef,
342 element_array: &ArrayRef,
343 arr_n: &[i64],
344) -> Result<ArrayRef> {
345 match array.data_type() {
346 DataType::List(_) => {
347 let list_array = array.as_list::<i32>();
348 general_remove::<i32>(list_array, element_array, arr_n)
349 }
350 DataType::LargeList(_) => {
351 let list_array = array.as_list::<i64>();
352 general_remove::<i64>(list_array, element_array, arr_n)
353 }
354 array_type => {
355 exec_err!("array_remove_all does not support type '{array_type}'.")
356 }
357 }
358}
359
/// Removes, from each row of `list_array`, up to `arr_n[row]` entries equal to
/// `element_array[row]`, scanning the row from left to right.
///
/// Each row is compared with `utils::compare_element_to_list(.., false)`. In
/// the result, a `Some(false)` marks an entry as a removal candidate (the
/// candidates are counted as `num_to_remove`). The first `arr_n[row]`
/// candidates are dropped. All other entries are kept in their original
/// order, including later candidates and entries whose comparison result is
/// not `Some(false)` (e.g. NULL comparisons). A non-positive `arr_n[row]`
/// removes nothing.
///
/// An output row is NULL when either the list row or `element_array[row]` is
/// NULL (the two null buffers are unioned). Such rows are emitted as empty
/// slices.
///
/// Kept entries are copied from the original child values in contiguous runs
/// via `MutableArrayData`. The output list reuses the input's child field.
///
/// `arr_n` must have one entry per row; it is indexed by row.
///
/// # Errors
/// Propagates errors from the per-row comparison and from building the output
/// `GenericListArray`. Also errors if the list's data type is not
/// `List`/`LargeList`.
fn general_remove<OffsetSize: OffsetSizeTrait>(
    list_array: &GenericListArray<OffsetSize>,
    element_array: &ArrayRef,
    arr_n: &[i64],
) -> Result<ArrayRef> {
    // Child field of the input list, reused unchanged for the output list.
    let list_field = match list_array.data_type() {
        DataType::List(field) | DataType::LargeList(field) => field,
        _ => {
            return exec_err!(
                "Expected List or LargeList data type, got {:?}",
                list_array.data_type()
            );
        }
    };
    let original_data = list_array.values().to_data();
    // One offset per row plus the leading zero.
    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
    offsets.push(OffsetSize::zero());

    // Output values are copied slices of the original child values. The output
    // never holds more values than the input, which bounds the capacity.
    let mut mutable = MutableArrayData::with_capacities(
        vec![&original_data],
        false,
        Capacities::Array(original_data.len()),
    );

    // Output row is NULL if the list row or the element for that row is NULL.
    let nulls = NullBuffer::union(list_array.nulls(), element_array.nulls());

    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
        // NULL output rows get an empty slice (offset repeated).
        if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
            offsets.push(offsets[row_index]);
            continue;
        }

        // Bounds of this row inside the child values.
        let start = offset_window[0].to_usize().unwrap();
        let end = offset_window[1].to_usize().unwrap();
        let n = arr_n[row_index];

        // `Some(false)` marks entries equal to this row's element.
        let eq_array = utils::compare_element_to_list(
            &list_array.value(row_index),
            element_array,
            row_index,
            false,
        )?;

        let num_to_remove = eq_array.false_count();

        // Fast path: no matches, copy the whole row in a single extend.
        if num_to_remove == 0 {
            mutable.extend(0, start, end);
            offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start));
            continue;
        }

        // Cap removals at the number of matches; a non-positive `n` removes none.
        let max_removals = n.min(num_to_remove as i64);
        let mut removed = 0i64;
        let mut copied = 0usize;
        // Row-relative start index of the current run of entries being kept.
        let mut pending_batch_to_retain: Option<usize> = None;
        for (i, keep) in eq_array.iter().enumerate() {
            if keep == Some(false) && removed < max_removals {
                // This entry is dropped: flush the kept run that precedes it.
                if let Some(bs) = pending_batch_to_retain {
                    mutable.extend(0, start + bs, start + i);
                    copied += i - bs;
                    pending_batch_to_retain = None;
                }
                removed += 1;
            } else if pending_batch_to_retain.is_none() {
                // This entry is kept (non-match, NULL comparison, or a match past
                // the limit); open a new run here unless one is already open.
                pending_batch_to_retain = Some(i);
            }
        }

        // Flush the trailing run of kept entries, if any.
        if let Some(bs) = pending_batch_to_retain {
            mutable.extend(0, start + bs, start + eq_array.len());
            copied += eq_array.len() - bs;
        }

        offsets.push(offsets[row_index] + OffsetSize::usize_as(copied));
    }

    let new_values = make_array(mutable.freeze());
    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
        Arc::clone(list_field),
        OffsetBuffer::new(offsets.into()),
        new_values,
        nulls,
    )?))
}
470
#[cfg(test)]
mod tests {
    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
    use arrow::array::{
        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
    };
    use arrow::datatypes::{DataType, Field, Int32Type};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use std::ops::Deref;
    use std::sync::Arc;

    #[test]
    fn test_array_remove_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                ];
                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];

                let result = ArrayRemove::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    #[test]
    fn test_array_remove_n_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                    Arc::new(Field::new("b", DataType::Int64, false)),
                ];
                let scalar_args = vec![
                    None,
                    Some(&ScalarValue::Int32(Some(1))),
                    Some(&ScalarValue::Int64(Some(1))),
                ];

                let result = ArrayRemoveN::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    #[test]
    fn test_array_remove_all_nullability() {
        for nullability in [true, false] {
            for item_nullability in [true, false] {
                let input_field = Arc::new(Field::new(
                    "num",
                    DataType::new_list(DataType::Int32, item_nullability),
                    nullability,
                ));
                // `array_remove_all(array, element)` takes two arguments; pass both
                // fields, as the other nullability tests do.
                let args_fields = vec![
                    Arc::clone(&input_field),
                    Arc::new(Field::new("a", DataType::Int32, false)),
                ];
                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];

                let result = ArrayRemoveAll::new()
                    .return_field_from_args(ReturnFieldArgs {
                        arg_fields: &args_fields,
                        scalar_arguments: &scalar_args,
                    })
                    .unwrap();

                assert_eq!(result, input_field);
            }
        }
    }

    /// Rebuilds `list` with its child field's nullability set to
    /// `field_nullable`. When making the field non-nullable, asserts that the
    /// list has no top-level null buffer.
    fn ensure_field_nullability<O: OffsetSizeTrait>(
        field_nullable: bool,
        list: GenericListArray<O>,
    ) -> GenericListArray<O> {
        let (field, offsets, values, nulls) = list.into_parts();

        if field.is_nullable() == field_nullable {
            return GenericListArray::new(field, offsets, values, nulls);
        }
        if !field_nullable {
            assert_eq!(nulls, None);
        }

        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));

        GenericListArray::new(field, offsets, values, nulls)
    }

    #[test]
    fn test_array_remove_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove(input_list, expected_list, element_to_remove);
    }

    #[test]
    fn test_array_remove_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63), Some(2)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove(input_list, expected_list, element_to_remove);
    }

    /// Invokes `array_remove(input_list, element_to_remove)` and checks that
    /// the result equals `expected_list`, with the input's data type preserved.
    fn assert_array_remove(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let udf = ArrayRemove::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let scalar_args = vec![None, Some(&element_to_remove)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    #[test]
    fn test_array_remove_n_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
    }

    #[test]
    fn test_array_remove_n_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
    }

    /// Invokes `array_remove_n(input_list, element_to_remove, n)` and checks
    /// that the result equals `expected_list`, with the input's data type
    /// preserved.
    fn assert_array_remove_n(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
        n: i64,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let count_scalar = ScalarValue::Int64(Some(n));

        let udf = ArrayRemoveN::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
            Arc::new(Field::new("count", DataType::Int64, false)),
        ];
        let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                    ColumnarValue::Scalar(count_scalar),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }

    #[test]
    fn test_array_remove_all_non_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            false,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(([1, 3, 1, 4]).iter().copied().map(Some)),
                Some(([42, 55, 63]).iter().copied().map(Some)),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_all(input_list, expected_list, element_to_remove);
    }

    #[test]
    fn test_array_remove_all_nullable() {
        let input_list = Arc::new(ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![
                    Some(1),
                    Some(2),
                    Some(2),
                    Some(3),
                    None,
                    Some(1),
                    Some(4),
                ]),
                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
            ]),
        ));
        let expected_list = ensure_field_nullability(
            true,
            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
                Some(vec![Some(42), None, Some(63)]),
            ]),
        );

        let element_to_remove = ScalarValue::Int32(Some(2));

        assert_array_remove_all(input_list, expected_list, element_to_remove);
    }

    /// Invokes `array_remove_all(input_list, element_to_remove)` and checks
    /// that the result equals `expected_list`, with the input's data type
    /// preserved.
    fn assert_array_remove_all(
        input_list: ArrayRef,
        expected_list: GenericListArray<i32>,
        element_to_remove: ScalarValue,
    ) {
        assert_eq!(input_list.data_type(), expected_list.data_type());
        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
        let input_list_len = input_list.len();
        let input_list_data_type = input_list.data_type().clone();

        let udf = ArrayRemoveAll::new();
        let args_fields = vec![
            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
            Arc::new(Field::new(
                "el",
                element_to_remove.data_type(),
                element_to_remove.is_null(),
            )),
        ];
        let scalar_args = vec![None, Some(&element_to_remove)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &args_fields,
                scalar_arguments: &scalar_args,
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(input_list),
                    ColumnarValue::Scalar(element_to_remove),
                ],
                arg_fields: args_fields,
                number_rows: input_list_len,
                return_field,
                config_options: Arc::new(Default::default()),
            })
            .unwrap();

        assert_eq!(result.data_type(), input_list_data_type);
        match result {
            ColumnarValue::Array(array) => {
                let result_list = array.as_list::<i32>();
                assert_eq!(result_list, &expected_list);
            }
            _ => panic!("Expected ColumnarValue::Array"),
        }
    }
}