1use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23 cast::AsArray, new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray,
24 OffsetSizeTrait,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, Field};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{exec_err, utils::take_function_args, Result};
31use datafusion_expr::{
32 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33 ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Wires `ArrayRemove` into the crate's shared UDF macro. Judging by the
// arguments, this presumably generates an `array_remove(array, element)`
// expression helper and an `array_remove_udf` accessor — confirm against the
// macro definition.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
46
// UDF implementing `array_remove` (alias `list_remove`): removes at most one
// occurrence of `element` from each list row. The `#[user_doc]` attribute
// carries the SQL-facing documentation that `documentation()` exposes via
// `self.doc()` (presumably generated by the attribute macro — confirm).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4] |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
    // Accepted argument types; built with `Signature::array_and_element` in `new`.
    signature: Signature,
    // Alternative SQL names for this function (`list_remove`).
    aliases: Vec<String>,
}
73
74impl Default for ArrayRemove {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl ArrayRemove {
81 pub fn new() -> Self {
82 Self {
83 signature: Signature::array_and_element(Volatility::Immutable),
84 aliases: vec!["list_remove".to_string()],
85 }
86 }
87}
88
89impl ScalarUDFImpl for ArrayRemove {
90 fn as_any(&self) -> &dyn Any {
91 self
92 }
93
94 fn name(&self) -> &str {
95 "array_remove"
96 }
97
98 fn signature(&self) -> &Signature {
99 &self.signature
100 }
101
102 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
103 Ok(arg_types[0].clone())
104 }
105
106 fn invoke_with_args(
107 &self,
108 args: datafusion_expr::ScalarFunctionArgs,
109 ) -> Result<ColumnarValue> {
110 make_scalar_function(array_remove_inner)(&args.args)
111 }
112
113 fn aliases(&self) -> &[String] {
114 &self.aliases
115 }
116
117 fn documentation(&self) -> Option<&Documentation> {
118 self.doc()
119 }
120}
121
// Wires `ArrayRemoveN` into the crate's shared UDF macro. Judging by the
// arguments, this presumably generates an `array_remove_n(array, element, max)`
// expression helper and an `array_remove_n_udf` accessor — confirm against the
// macro definition.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
129
130#[user_doc(
131 doc_section(label = "Array Functions"),
132 description = "Removes the first `max` elements from the array equal to the given value.",
133 syntax_example = "array_remove_n(array, element, max))",
134 sql_example = r#"```sql
135> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
136+---------------------------------------------------------+
137| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
138+---------------------------------------------------------+
139| [1, 3, 2, 1, 4] |
140+---------------------------------------------------------+
141```"#,
142 argument(
143 name = "array",
144 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
145 ),
146 argument(
147 name = "element",
148 description = "Element to be removed from the array."
149 ),
150 argument(name = "max", description = "Number of first occurrences to remove.")
151)]
152#[derive(Debug, PartialEq, Eq, Hash)]
153pub(super) struct ArrayRemoveN {
154 signature: Signature,
155 aliases: Vec<String>,
156}
157
158impl ArrayRemoveN {
159 pub fn new() -> Self {
160 Self {
161 signature: Signature::new(
162 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
163 arguments: vec![
164 ArrayFunctionArgument::Array,
165 ArrayFunctionArgument::Element,
166 ArrayFunctionArgument::Index,
167 ],
168 array_coercion: Some(ListCoercion::FixedSizedListToList),
169 }),
170 Volatility::Immutable,
171 ),
172 aliases: vec!["list_remove_n".to_string()],
173 }
174 }
175}
176
177impl ScalarUDFImpl for ArrayRemoveN {
178 fn as_any(&self) -> &dyn Any {
179 self
180 }
181
182 fn name(&self) -> &str {
183 "array_remove_n"
184 }
185
186 fn signature(&self) -> &Signature {
187 &self.signature
188 }
189
190 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
191 Ok(arg_types[0].clone())
192 }
193
194 fn invoke_with_args(
195 &self,
196 args: datafusion_expr::ScalarFunctionArgs,
197 ) -> Result<ColumnarValue> {
198 make_scalar_function(array_remove_n_inner)(&args.args)
199 }
200
201 fn aliases(&self) -> &[String] {
202 &self.aliases
203 }
204
205 fn documentation(&self) -> Option<&Documentation> {
206 self.doc()
207 }
208}
209
// Wires `ArrayRemoveAll` into the crate's shared UDF macro. Judging by the
// arguments, this presumably generates an `array_remove_all(array, element)`
// expression helper and an `array_remove_all_udf` accessor — confirm against
// the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
217
// UDF implementing `array_remove_all` (alias `list_remove_all`): removes every
// occurrence of `element` from each list row. The `#[user_doc]` attribute
// carries the SQL-facing documentation that `documentation()` exposes via
// `self.doc()` (presumably generated by the attribute macro — confirm).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4] |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveAll {
    // Accepted argument types; built with `Signature::array_and_element` in `new`.
    signature: Signature,
    // Alternative SQL names for this function (`list_remove_all`).
    aliases: Vec<String>,
}
244
245impl ArrayRemoveAll {
246 pub fn new() -> Self {
247 Self {
248 signature: Signature::array_and_element(Volatility::Immutable),
249 aliases: vec!["list_remove_all".to_string()],
250 }
251 }
252}
253
254impl ScalarUDFImpl for ArrayRemoveAll {
255 fn as_any(&self) -> &dyn Any {
256 self
257 }
258
259 fn name(&self) -> &str {
260 "array_remove_all"
261 }
262
263 fn signature(&self) -> &Signature {
264 &self.signature
265 }
266
267 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
268 Ok(arg_types[0].clone())
269 }
270
271 fn invoke_with_args(
272 &self,
273 args: datafusion_expr::ScalarFunctionArgs,
274 ) -> Result<ColumnarValue> {
275 make_scalar_function(array_remove_all_inner)(&args.args)
276 }
277
278 fn aliases(&self) -> &[String] {
279 &self.aliases
280 }
281
282 fn documentation(&self) -> Option<&Documentation> {
283 self.doc()
284 }
285}
286
287pub fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
289 let [array, element] = take_function_args("array_remove", args)?;
290
291 let arr_n = vec![1; array.len()];
292 array_remove_internal(array, element, arr_n)
293}
294
295pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
297 let [array, element, max] = take_function_args("array_remove_n", args)?;
298
299 let arr_n = as_int64_array(max)?.values().to_vec();
300 array_remove_internal(array, element, arr_n)
301}
302
303pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
305 let [array, element] = take_function_args("array_remove_all", args)?;
306
307 let arr_n = vec![i64::MAX; array.len()];
308 array_remove_internal(array, element, arr_n)
309}
310
311fn array_remove_internal(
312 array: &ArrayRef,
313 element_array: &ArrayRef,
314 arr_n: Vec<i64>,
315) -> Result<ArrayRef> {
316 match array.data_type() {
317 DataType::List(_) => {
318 let list_array = array.as_list::<i32>();
319 general_remove::<i32>(list_array, element_array, arr_n)
320 }
321 DataType::LargeList(_) => {
322 let list_array = array.as_list::<i64>();
323 general_remove::<i64>(list_array, element_array, arr_n)
324 }
325 array_type => {
326 exec_err!("array_remove_all does not support type '{array_type}'.")
327 }
328 }
329}
330
331fn general_remove<OffsetSize: OffsetSizeTrait>(
349 list_array: &GenericListArray<OffsetSize>,
350 element_array: &ArrayRef,
351 arr_n: Vec<i64>,
352) -> Result<ArrayRef> {
353 let data_type = list_array.value_type();
354 let mut new_values = vec![];
355 let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
357 offsets.push(OffsetSize::zero());
358
359 for (row_index, (list_array_row, n)) in
361 list_array.iter().zip(arr_n.iter()).enumerate()
362 {
363 match list_array_row {
364 Some(list_array_row) => {
365 let eq_array = utils::compare_element_to_list(
366 &list_array_row,
367 element_array,
368 row_index,
369 false,
370 )?;
371
372 let eq_array = if eq_array.false_count() < *n as usize {
374 eq_array
375 } else {
376 let mut count = 0;
377 eq_array
378 .iter()
379 .map(|e| {
380 if let Some(false) = e {
382 if count < *n {
383 count += 1;
384 e
385 } else {
386 Some(true)
387 }
388 } else {
389 e
390 }
391 })
392 .collect::<BooleanArray>()
393 };
394
395 let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
396 offsets.push(
397 offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
398 );
399 new_values.push(filtered_array);
400 }
401 None => {
402 offsets.push(offsets[row_index]);
404 }
405 }
406 }
407
408 let values = if new_values.is_empty() {
409 new_empty_array(&data_type)
410 } else {
411 let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
412 arrow::compute::concat(&new_values)?
413 };
414
415 Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
416 Arc::new(Field::new_list_field(data_type, true)),
417 OffsetBuffer::new(offsets.into()),
418 values,
419 list_array.nulls().cloned(),
420 )?))
421}