datafusion_functions_nested/
remove.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
19
20use crate::utils;
21use crate::utils::make_scalar_function;
22use arrow::array::{
23    cast::AsArray, new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray,
24    OffsetSizeTrait,
25};
26use arrow::buffer::OffsetBuffer;
27use arrow::datatypes::{DataType, Field};
28use datafusion_common::cast::as_int64_array;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
32};
33use datafusion_macros::user_doc;
34use std::any::Any;
35use std::sync::Arc;
36
37make_udf_expr_and_func!(
38    ArrayRemove,
39    array_remove,
40    array element,
41    "removes the first element from the array equal to the given value.",
42    array_remove_udf
43);
44
45#[user_doc(
46    doc_section(label = "Array Functions"),
47    description = "Removes the first element from the array equal to the given value.",
48    syntax_example = "array_remove(array, element)",
49    sql_example = r#"```sql
50> select array_remove([1, 2, 2, 3, 2, 1, 4], 2);
51+----------------------------------------------+
52| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) |
53+----------------------------------------------+
54| [1, 2, 3, 2, 1, 4]                           |
55+----------------------------------------------+
56```"#,
57    argument(
58        name = "array",
59        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
60    ),
61    argument(
62        name = "element",
63        description = "Element to be removed from the array."
64    )
65)]
66#[derive(Debug)]
67pub struct ArrayRemove {
68    signature: Signature,
69    aliases: Vec<String>,
70}
71
72impl Default for ArrayRemove {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl ArrayRemove {
79    pub fn new() -> Self {
80        Self {
81            signature: Signature::array_and_element(Volatility::Immutable),
82            aliases: vec!["list_remove".to_string()],
83        }
84    }
85}
86
87impl ScalarUDFImpl for ArrayRemove {
88    fn as_any(&self) -> &dyn Any {
89        self
90    }
91
92    fn name(&self) -> &str {
93        "array_remove"
94    }
95
96    fn signature(&self) -> &Signature {
97        &self.signature
98    }
99
100    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101        Ok(arg_types[0].clone())
102    }
103
104    fn invoke_with_args(
105        &self,
106        args: datafusion_expr::ScalarFunctionArgs,
107    ) -> Result<ColumnarValue> {
108        make_scalar_function(array_remove_inner)(&args.args)
109    }
110
111    fn aliases(&self) -> &[String] {
112        &self.aliases
113    }
114
115    fn documentation(&self) -> Option<&Documentation> {
116        self.doc()
117    }
118}
119
120make_udf_expr_and_func!(
121    ArrayRemoveN,
122    array_remove_n,
123    array element max,
124    "removes the first `max` elements from the array equal to the given value.",
125    array_remove_n_udf
126);
127
128#[user_doc(
129    doc_section(label = "Array Functions"),
130    description = "Removes the first `max` elements from the array equal to the given value.",
131    syntax_example = "array_remove_n(array, element, max))",
132    sql_example = r#"```sql
133> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2);
134+---------------------------------------------------------+
135| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) |
136+---------------------------------------------------------+
137| [1, 3, 2, 1, 4]                                         |
138+---------------------------------------------------------+
139```"#,
140    argument(
141        name = "array",
142        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
143    ),
144    argument(
145        name = "element",
146        description = "Element to be removed from the array."
147    ),
148    argument(name = "max", description = "Number of first occurrences to remove.")
149)]
150#[derive(Debug)]
151pub(super) struct ArrayRemoveN {
152    signature: Signature,
153    aliases: Vec<String>,
154}
155
156impl ArrayRemoveN {
157    pub fn new() -> Self {
158        Self {
159            signature: Signature::any(3, Volatility::Immutable),
160            aliases: vec!["list_remove_n".to_string()],
161        }
162    }
163}
164
165impl ScalarUDFImpl for ArrayRemoveN {
166    fn as_any(&self) -> &dyn Any {
167        self
168    }
169
170    fn name(&self) -> &str {
171        "array_remove_n"
172    }
173
174    fn signature(&self) -> &Signature {
175        &self.signature
176    }
177
178    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
179        Ok(arg_types[0].clone())
180    }
181
182    fn invoke_with_args(
183        &self,
184        args: datafusion_expr::ScalarFunctionArgs,
185    ) -> Result<ColumnarValue> {
186        make_scalar_function(array_remove_n_inner)(&args.args)
187    }
188
189    fn aliases(&self) -> &[String] {
190        &self.aliases
191    }
192
193    fn documentation(&self) -> Option<&Documentation> {
194        self.doc()
195    }
196}
197
198make_udf_expr_and_func!(
199    ArrayRemoveAll,
200    array_remove_all,
201    array element,
202    "removes all elements from the array equal to the given value.",
203    array_remove_all_udf
204);
205
206#[user_doc(
207    doc_section(label = "Array Functions"),
208    description = "Removes all elements from the array equal to the given value.",
209    syntax_example = "array_remove_all(array, element)",
210    sql_example = r#"```sql
211> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2);
212+--------------------------------------------------+
213| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) |
214+--------------------------------------------------+
215| [1, 3, 1, 4]                                     |
216+--------------------------------------------------+
217```"#,
218    argument(
219        name = "array",
220        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
221    ),
222    argument(
223        name = "element",
224        description = "Element to be removed from the array."
225    )
226)]
227#[derive(Debug)]
228pub(super) struct ArrayRemoveAll {
229    signature: Signature,
230    aliases: Vec<String>,
231}
232
233impl ArrayRemoveAll {
234    pub fn new() -> Self {
235        Self {
236            signature: Signature::array_and_element(Volatility::Immutable),
237            aliases: vec!["list_remove_all".to_string()],
238        }
239    }
240}
241
242impl ScalarUDFImpl for ArrayRemoveAll {
243    fn as_any(&self) -> &dyn Any {
244        self
245    }
246
247    fn name(&self) -> &str {
248        "array_remove_all"
249    }
250
251    fn signature(&self) -> &Signature {
252        &self.signature
253    }
254
255    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
256        Ok(arg_types[0].clone())
257    }
258
259    fn invoke_with_args(
260        &self,
261        args: datafusion_expr::ScalarFunctionArgs,
262    ) -> Result<ColumnarValue> {
263        make_scalar_function(array_remove_all_inner)(&args.args)
264    }
265
266    fn aliases(&self) -> &[String] {
267        &self.aliases
268    }
269
270    fn documentation(&self) -> Option<&Documentation> {
271        self.doc()
272    }
273}
274
275/// Array_remove SQL function
276pub fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
277    let [array, element] = take_function_args("array_remove", args)?;
278
279    let arr_n = vec![1; array.len()];
280    array_remove_internal(array, element, arr_n)
281}
282
283/// Array_remove_n SQL function
284pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
285    let [array, element, max] = take_function_args("array_remove_n", args)?;
286
287    let arr_n = as_int64_array(max)?.values().to_vec();
288    array_remove_internal(array, element, arr_n)
289}
290
291/// Array_remove_all SQL function
292pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
293    let [array, element] = take_function_args("array_remove_all", args)?;
294
295    let arr_n = vec![i64::MAX; array.len()];
296    array_remove_internal(array, element, arr_n)
297}
298
299fn array_remove_internal(
300    array: &ArrayRef,
301    element_array: &ArrayRef,
302    arr_n: Vec<i64>,
303) -> Result<ArrayRef> {
304    match array.data_type() {
305        DataType::List(_) => {
306            let list_array = array.as_list::<i32>();
307            general_remove::<i32>(list_array, element_array, arr_n)
308        }
309        DataType::LargeList(_) => {
310            let list_array = array.as_list::<i64>();
311            general_remove::<i64>(list_array, element_array, arr_n)
312        }
313        array_type => {
314            exec_err!("array_remove_all does not support type '{array_type:?}'.")
315        }
316    }
317}
318
319/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
320/// of `element_array[i]`.
321///
322/// The type of each **element** in `list_array` must be the same as the type of
323/// `element_array`. This function also handles nested arrays
324/// ([`arrow::array::ListArray`] of [`arrow::array::ListArray`]s)
325///
326/// For example, when called to remove a list array (where each element is a
327/// list of int32s, the second argument are int32 arrays, and the
328/// third argument is the number of occurrences to remove
329///
330/// ```text
331/// general_remove(
332///   [1, 2, 3, 2], 2, 1    ==> [1, 3, 2]   (only the first 2 is removed)
333///   [4, 5, 6, 5], 5, 2    ==> [4, 6]  (both 5s are removed)
334/// )
335/// ```
336fn general_remove<OffsetSize: OffsetSizeTrait>(
337    list_array: &GenericListArray<OffsetSize>,
338    element_array: &ArrayRef,
339    arr_n: Vec<i64>,
340) -> Result<ArrayRef> {
341    let data_type = list_array.value_type();
342    let mut new_values = vec![];
343    // Build up the offsets for the final output array
344    let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
345    offsets.push(OffsetSize::zero());
346
347    // n is the number of elements to remove in this row
348    for (row_index, (list_array_row, n)) in
349        list_array.iter().zip(arr_n.iter()).enumerate()
350    {
351        match list_array_row {
352            Some(list_array_row) => {
353                let eq_array = utils::compare_element_to_list(
354                    &list_array_row,
355                    element_array,
356                    row_index,
357                    false,
358                )?;
359
360                // We need to keep at most first n elements as `false`, which represent the elements to remove.
361                let eq_array = if eq_array.false_count() < *n as usize {
362                    eq_array
363                } else {
364                    let mut count = 0;
365                    eq_array
366                        .iter()
367                        .map(|e| {
368                            // Keep first n `false` elements, and reverse other elements to `true`.
369                            if let Some(false) = e {
370                                if count < *n {
371                                    count += 1;
372                                    e
373                                } else {
374                                    Some(true)
375                                }
376                            } else {
377                                e
378                            }
379                        })
380                        .collect::<BooleanArray>()
381                };
382
383                let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?;
384                offsets.push(
385                    offsets[row_index] + OffsetSize::usize_as(filtered_array.len()),
386                );
387                new_values.push(filtered_array);
388            }
389            None => {
390                // Null element results in a null row (no new offsets)
391                offsets.push(offsets[row_index]);
392            }
393        }
394    }
395
396    let values = if new_values.is_empty() {
397        new_empty_array(&data_type)
398    } else {
399        let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
400        arrow::compute::concat(&new_values)?
401    };
402
403    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
404        Arc::new(Field::new_list_field(data_type, true)),
405        OffsetBuffer::new(offsets.into()),
406        values,
407        list_array.nulls().cloned(),
408    )?))
409}