1use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23 cast::AsArray, new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray,
24 OffsetSizeTrait,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, Field};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34use std::any::Any;
35use std::sync::Arc;
36
// Hooks `ArrayRemove` into the `make_udf_expr_and_func!` macro (defined outside
// this file). The arguments are, in order: the implementing type, the name
// `array_remove`, the argument names `array element`, a one-line description,
// and the name `array_remove_udf`.
// NOTE(review): presumably this generates the `array_remove(..)` expression
// helper and the `array_remove_udf()` accessor — confirm in the macro definition.
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
44
/// Scalar UDF implementation behind `array_remove(array, element)`.
///
/// The user-facing documentation (description, syntax, SQL example and
/// argument list) is declared in the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4] |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug)]
pub struct ArrayRemove {
    /// Signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative function names handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
71
72impl Default for ArrayRemove {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl ArrayRemove {
79 pub fn new() -> Self {
80 Self {
81 signature: Signature::array_and_element(Volatility::Immutable),
82 aliases: vec!["list_remove".to_string()],
83 }
84 }
85}
86
87impl ScalarUDFImpl for ArrayRemove {
88 fn as_any(&self) -> &dyn Any {
89 self
90 }
91
92 fn name(&self) -> &str {
93 "array_remove"
94 }
95
96 fn signature(&self) -> &Signature {
97 &self.signature
98 }
99
100 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101 Ok(arg_types[0].clone())
102 }
103
104 fn invoke_with_args(
105 &self,
106 args: datafusion_expr::ScalarFunctionArgs,
107 ) -> Result<ColumnarValue> {
108 make_scalar_function(array_remove_inner)(&args.args)
109 }
110
111 fn aliases(&self) -> &[String] {
112 &self.aliases
113 }
114
115 fn documentation(&self) -> Option<&Documentation> {
116 self.doc()
117 }
118}
119
// Hooks `ArrayRemoveN` into the `make_udf_expr_and_func!` macro (defined
// outside this file). The arguments are, in order: the implementing type, the
// name `array_remove_n`, the argument names `array element max`, a one-line
// description, and the name `array_remove_n_udf`.
// NOTE(review): presumably this generates the `array_remove_n(..)` expression
// helper and the `array_remove_n_udf()` accessor — confirm in the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
127
128#[user_doc(
129 doc_section(label = "Array Functions"),
130 description = "Removes the first `max` elements from the array equal to the given value.",
131 syntax_example = "array_remove_n(array, element, max))",
132 sql_example = r#"```sql
133> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
134+---------------------------------------------------------+
135| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
136+---------------------------------------------------------+
137| [1, 3, 2, 1, 4] |
138+---------------------------------------------------------+
139```"#,
140 argument(
141 name = "array",
142 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
143 ),
144 argument(
145 name = "element",
146 description = "Element to be removed from the array."
147 ),
148 argument(name = "max", description = "Number of first occurrences to remove.")
149)]
150#[derive(Debug)]
151pub(super) struct ArrayRemoveN {
152 signature: Signature,
153 aliases: Vec<String>,
154}
155
156impl ArrayRemoveN {
157 pub fn new() -> Self {
158 Self {
159 signature: Signature::any(3, Volatility::Immutable),
160 aliases: vec!["list_remove_n".to_string()],
161 }
162 }
163}
164
165impl ScalarUDFImpl for ArrayRemoveN {
166 fn as_any(&self) -> &dyn Any {
167 self
168 }
169
170 fn name(&self) -> &str {
171 "array_remove_n"
172 }
173
174 fn signature(&self) -> &Signature {
175 &self.signature
176 }
177
178 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
179 Ok(arg_types[0].clone())
180 }
181
182 fn invoke_with_args(
183 &self,
184 args: datafusion_expr::ScalarFunctionArgs,
185 ) -> Result<ColumnarValue> {
186 make_scalar_function(array_remove_n_inner)(&args.args)
187 }
188
189 fn aliases(&self) -> &[String] {
190 &self.aliases
191 }
192
193 fn documentation(&self) -> Option<&Documentation> {
194 self.doc()
195 }
196}
197
// Hooks `ArrayRemoveAll` into the `make_udf_expr_and_func!` macro (defined
// outside this file). The arguments are, in order: the implementing type, the
// name `array_remove_all`, the argument names `array element`, a one-line
// description, and the name `array_remove_all_udf`.
// NOTE(review): presumably this generates the `array_remove_all(..)` expression
// helper and the `array_remove_all_udf()` accessor — confirm in the macro definition.
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
205
/// Scalar UDF implementation behind `array_remove_all(array, element)`.
///
/// The user-facing documentation (description, syntax, SQL example and
/// argument list) is declared in the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4] |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug)]
pub(super) struct ArrayRemoveAll {
    /// Signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative function names handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
232
233impl ArrayRemoveAll {
234 pub fn new() -> Self {
235 Self {
236 signature: Signature::array_and_element(Volatility::Immutable),
237 aliases: vec!["list_remove_all".to_string()],
238 }
239 }
240}
241
242impl ScalarUDFImpl for ArrayRemoveAll {
243 fn as_any(&self) -> &dyn Any {
244 self
245 }
246
247 fn name(&self) -> &str {
248 "array_remove_all"
249 }
250
251 fn signature(&self) -> &Signature {
252 &self.signature
253 }
254
255 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
256 Ok(arg_types[0].clone())
257 }
258
259 fn invoke_with_args(
260 &self,
261 args: datafusion_expr::ScalarFunctionArgs,
262 ) -> Result<ColumnarValue> {
263 make_scalar_function(array_remove_all_inner)(&args.args)
264 }
265
266 fn aliases(&self) -> &[String] {
267 &self.aliases
268 }
269
270 fn documentation(&self) -> Option<&Documentation> {
271 self.doc()
272 }
273}
274
275pub fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
277 let [array, element] = take_function_args("array_remove", args)?;
278
279 let arr_n = vec![1; array.len()];
280 array_remove_internal(array, element, arr_n)
281}
282
283pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
285 let [array, element, max] = take_function_args("array_remove_n", args)?;
286
287 let arr_n = as_int64_array(max)?.values().to_vec();
288 array_remove_internal(array, element, arr_n)
289}
290
291pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
293 let [array, element] = take_function_args("array_remove_all", args)?;
294
295 let arr_n = vec![i64::MAX; array.len()];
296 array_remove_internal(array, element, arr_n)
297}
298
299fn array_remove_internal(
300 array: &ArrayRef,
301 element_array: &ArrayRef,
302 arr_n: Vec<i64>,
303) -> Result<ArrayRef> {
304 match array.data_type() {
305 DataType::List(_) => {
306 let list_array = array.as_list::<i32>();
307 general_remove::<i32>(list_array, element_array, arr_n)
308 }
309 DataType::LargeList(_) => {
310 let list_array = array.as_list::<i64>();
311 general_remove::<i64>(list_array, element_array, arr_n)
312 }
313 array_type => {
314 exec_err!("array_remove_all does not support type '{array_type:?}'.")
315 }
316 }
317}
318
319fn general_remove<OffsetSize: OffsetSizeTrait>(
337 list_array: &GenericListArray<OffsetSize>,
338 element_array: &ArrayRef,
339 arr_n: Vec<i64>,
340) -> Result<ArrayRef> {
341 let data_type = list_array.value_type();
342 let mut new_values = vec![];
343 let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
345 offsets.push(OffsetSize::zero());
346
347 for (row_index, (list_array_row, n)) in
349 list_array.iter().zip(arr_n.iter()).enumerate()
350 {
351 match list_array_row {
352 Some(list_array_row) => {
353 let eq_array = utils::compare_element_to_list(
354 &list_array_row,
355 element_array,
356 row_index,
357 false,
358 )?;
359
360 let eq_array = if eq_array.false_count() < *n as usize {
362 eq_array
363 } else {
364 let mut count = 0;
365 eq_array
366 .iter()
367 .map(|e| {
368 if let Some(false) = e {
370 if count < *n {
371 count += 1;
372 e
373 } else {
374 Some(true)
375 }
376 } else {
377 e
378 }
379 })
380 .collect::<BooleanArray>()
381 };
382
383 let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
384 offsets.push(
385 offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
386 );
387 new_values.push(filtered_array);
388 }
389 None => {
390 offsets.push(offsets[row_index]);
392 }
393 }
394 }
395
396 let values = if new_values.is_empty() {
397 new_empty_array(&data_type)
398 } else {
399 let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
400 arrow::compute::concat(&new_values)?
401 };
402
403 Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
404 Arc::new(Field::new_list_field(data_type, true)),
405 OffsetBuffer::new(offsets.into()),
406 values,
407 list_array.nulls().cloned(),
408 )?))
409}