datafusion_functions_nested/
remove.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
19
20use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23    cast::AsArray, new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray,
24    OffsetSizeTrait,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, Field};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::utils::ListCoercion;
30use datafusion_common::{exec_err, utils::take_function_args, Result};
31use datafusion_expr::{
32    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
33    ScalarUDFImpl, Signature, TypeSignature, Volatility,
34};
35use datafusion_macros::user_doc;
36use std::any::Any;
37use std::sync::Arc;
38
// Per its arguments, this presumably generates an `array_remove(array, element)`
// expression helper and the `array_remove_udf()` accessor for `ArrayRemove`
// (see the `make_udf_expr_and_func!` definition for the exact expansion).
make_udf_expr_and_func!(
    ArrayRemove,
    array_remove,
    array element,
    "removes the first element from the array equal to the given value.",
    array_remove_udf
);
46
// User-facing documentation for `array_remove`; the `ScalarUDFImpl` impl
// returns it from `documentation()` via `self.doc()`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes the first element from the array equal to the given value.",
    syntax_example = "array_remove(array, element)",
    sql_example = r#"```sql
> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
+----------------------------------------------+
| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
+----------------------------------------------+
| [1, 2, 3, 2, 1, 4]                           |
+----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRemove {
    // Accepted argument types: an array followed by an element.
    signature: Signature,
    // Alternative SQL names returned from `aliases()`.
    aliases: Vec<String>,
}
73
74impl Default for ArrayRemove {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl ArrayRemove {
81    pub fn new() -> Self {
82        Self {
83            signature: Signature::array_and_element(Volatility::Immutable),
84            aliases: vec!["list_remove".to_string()],
85        }
86    }
87}
88
89impl ScalarUDFImpl for ArrayRemove {
90    fn as_any(&self) -> &dyn Any {
91        self
92    }
93
94    fn name(&self) -> &str {
95        "array_remove"
96    }
97
98    fn signature(&self) -> &Signature {
99        &self.signature
100    }
101
102    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
103        Ok(arg_types[0].clone())
104    }
105
106    fn invoke_with_args(
107        &self,
108        args: datafusion_expr::ScalarFunctionArgs,
109    ) -> Result<ColumnarValue> {
110        make_scalar_function(array_remove_inner)(&args.args)
111    }
112
113    fn aliases(&self) -> &[String] {
114        &self.aliases
115    }
116
117    fn documentation(&self) -> Option<&Documentation> {
118        self.doc()
119    }
120}
121
// Per its arguments, this presumably generates an `array_remove_n(array, element, max)`
// expression helper and the `array_remove_n_udf()` accessor for `ArrayRemoveN`
// (see the `make_udf_expr_and_func!` definition for the exact expansion).
make_udf_expr_and_func!(
    ArrayRemoveN,
    array_remove_n,
    array element max,
    "removes the first `max` elements from the array equal to the given value.",
    array_remove_n_udf
);
129
130#[user_doc(
131    doc_section(label = "Array Functions"),
132    description = "Removes the first `max` elements from the array equal to the given value.",
133    syntax_example = "array_remove_n(array, element, max))",
134    sql_example = r#"```sql
135> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
136+---------------------------------------------------------+
137| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
138+---------------------------------------------------------+
139| [1, 3, 2, 1, 4]                                         |
140+---------------------------------------------------------+
141```"#,
142    argument(
143        name = "array",
144        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
145    ),
146    argument(
147        name = "element",
148        description = "Element to be removed from the array."
149    ),
150    argument(name = "max", description = "Number of first occurrences to remove.")
151)]
152#[derive(Debug, PartialEq, Eq, Hash)]
153pub(super) struct ArrayRemoveN {
154    signature: Signature,
155    aliases: Vec<String>,
156}
157
158impl ArrayRemoveN {
159    pub fn new() -> Self {
160        Self {
161            signature: Signature::new(
162                TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
163                    arguments: vec![
164                        ArrayFunctionArgument::Array,
165                        ArrayFunctionArgument::Element,
166                        ArrayFunctionArgument::Index,
167                    ],
168                    array_coercion: Some(ListCoercion::FixedSizedListToList),
169                }),
170                Volatility::Immutable,
171            ),
172            aliases: vec!["list_remove_n".to_string()],
173        }
174    }
175}
176
177impl ScalarUDFImpl for ArrayRemoveN {
178    fn as_any(&self) -> &dyn Any {
179        self
180    }
181
182    fn name(&self) -> &str {
183        "array_remove_n"
184    }
185
186    fn signature(&self) -> &Signature {
187        &self.signature
188    }
189
190    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
191        Ok(arg_types[0].clone())
192    }
193
194    fn invoke_with_args(
195        &self,
196        args: datafusion_expr::ScalarFunctionArgs,
197    ) -> Result<ColumnarValue> {
198        make_scalar_function(array_remove_n_inner)(&args.args)
199    }
200
201    fn aliases(&self) -> &[String] {
202        &self.aliases
203    }
204
205    fn documentation(&self) -> Option<&Documentation> {
206        self.doc()
207    }
208}
209
// Per its arguments, this presumably generates an `array_remove_all(array, element)`
// expression helper and the `array_remove_all_udf()` accessor for `ArrayRemoveAll`
// (see the `make_udf_expr_and_func!` definition for the exact expansion).
make_udf_expr_and_func!(
    ArrayRemoveAll,
    array_remove_all,
    array element,
    "removes all elements from the array equal to the given value.",
    array_remove_all_udf
);
217
// User-facing documentation for `array_remove_all`; the `ScalarUDFImpl` impl
// returns it from `documentation()` via `self.doc()`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Removes all elements from the array equal to the given value.",
    syntax_example = "array_remove_all(array, element)",
    sql_example = r#"```sql
> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
+--------------------------------------------------+
| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
+--------------------------------------------------+
| [1, 3, 1, 4]                                     |
+--------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to be removed from the array."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayRemoveAll {
    // Accepted argument types: an array followed by an element.
    signature: Signature,
    // Alternative SQL names returned from `aliases()`.
    aliases: Vec<String>,
}
244
245impl ArrayRemoveAll {
246    pub fn new() -> Self {
247        Self {
248            signature: Signature::array_and_element(Volatility::Immutable),
249            aliases: vec!["list_remove_all".to_string()],
250        }
251    }
252}
253
254impl ScalarUDFImpl for ArrayRemoveAll {
255    fn as_any(&self) -> &dyn Any {
256        self
257    }
258
259    fn name(&self) -> &str {
260        "array_remove_all"
261    }
262
263    fn signature(&self) -> &Signature {
264        &self.signature
265    }
266
267    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
268        Ok(arg_types[0].clone())
269    }
270
271    fn invoke_with_args(
272        &self,
273        args: datafusion_expr::ScalarFunctionArgs,
274    ) -> Result<ColumnarValue> {
275        make_scalar_function(array_remove_all_inner)(&args.args)
276    }
277
278    fn aliases(&self) -> &[String] {
279        &self.aliases
280    }
281
282    fn documentation(&self) -> Option<&Documentation> {
283        self.doc()
284    }
285}
286
287/// Array_remove SQL function
288pub fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
289    let [array, element] = take_function_args("array_remove", args)?;
290
291    let arr_n = vec![1; array.len()];
292    array_remove_internal(array, element, arr_n)
293}
294
295/// Array_remove_n SQL function
296pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
297    let [array, element, max] = take_function_args("array_remove_n", args)?;
298
299    let arr_n = as_int64_array(max)?.values().to_vec();
300    array_remove_internal(array, element, arr_n)
301}
302
303/// Array_remove_all SQL function
304pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
305    let [array, element] = take_function_args("array_remove_all", args)?;
306
307    let arr_n = vec![i64::MAX; array.len()];
308    array_remove_internal(array, element, arr_n)
309}
310
311fn array_remove_internal(
312    array: &ArrayRef,
313    element_array: &ArrayRef,
314    arr_n: Vec<i64>,
315) -> Result<ArrayRef> {
316    match array.data_type() {
317        DataType::List(_) => {
318            let list_array = array.as_list::<i32>();
319            general_remove::<i32>(list_array, element_array, arr_n)
320        }
321        DataType::LargeList(_) => {
322            let list_array = array.as_list::<i64>();
323            general_remove::<i64>(list_array, element_array, arr_n)
324        }
325        array_type => {
326            exec_err!("array_remove_all does not support type '{array_type}'.")
327        }
328    }
329}
330
331/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
332/// of `element_array[i]`.
333///
334/// The type of each **element** in `list_array` must be the same as the type of
335/// `element_array`. This function also handles nested arrays
336/// ([`arrow::array::ListArray`] of [`arrow::array::ListArray`]s)
337///
338/// For example, when called to remove a list array (where each element is a
339/// list of int32s, the second argument are int32 arrays, and the
340/// third argument is the number of occurrences to remove
341///
342/// ```text
343/// general_remove(
344///   [1, 2, 3, 2], 2, 1    ==> [1, 3, 2]   (only the first 2 is removed)
345///   [4, 5, 6, 5], 5, 2    ==> [4, 6]  (both 5s are removed)
346/// )
347/// ```
348fn general_remove<OffsetSize: OffsetSizeTrait>(
349    list_array: &GenericListArray<OffsetSize>,
350    element_array: &ArrayRef,
351    arr_n: Vec<i64>,
352) -> Result<ArrayRef> {
353    let data_type = list_array.value_type();
354    let mut new_values = vec![];
355    // Build up the offsets for the final output array
356    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
357    offsets.push(OffsetSize::zero());
358
359    // n is the number of elements to remove in this row
360    for (row_index, (list_array_row, n)) in
361        list_array.iter().zip(arr_n.iter()).enumerate()
362    {
363        match list_array_row {
364            Some(list_array_row) => {
365                let eq_array = utils::compare_element_to_list(
366                    &list_array_row,
367                    element_array,
368                    row_index,
369                    false,
370                )?;
371
372                // We need to keep at most first n elements as `false`, which represent the elements to remove.
373                let eq_array = if eq_array.false_count() < *n as usize {
374                    eq_array
375                } else {
376                    let mut count = 0;
377                    eq_array
378                        .iter()
379                        .map(|e| {
380                            // Keep first n `false` elements, and reverse other elements to `true`.
381                            if let Some(false) = e {
382                                if count < *n {
383                                    count += 1;
384                                    e
385                                } else {
386                                    Some(true)
387                                }
388                            } else {
389                                e
390                            }
391                        })
392                        .collect::<BooleanArray>()
393                };
394
395                let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
396                offsets.push(
397                    offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
398                );
399                new_values.push(filtered_array);
400            }
401            None => {
402                // Null element results in a null row (no new offsets)
403                offsets.push(offsets[row_index]);
404            }
405        }
406    }
407
408    let values = if new_values.is_empty() {
409        new_empty_array(&data_type)
410    } else {
411        let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
412        arrow::compute::concat(&new_values)?
413    };
414
415    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
416        Arc::new(Field::new_list_field(data_type, true)),
417        OffsetBuffer::new(offsets.into()),
418        values,
419        list_array.nulls().cloned(),
420    )?))
421}