datafusion_functions_nested/
replace.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions.
19
20use arrow::array::{
21    new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray,
22    MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
23};
24use arrow::datatypes::{DataType, Field};
25
26use arrow::buffer::OffsetBuffer;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32    ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::utils::compare_element_to_list;
37use crate::utils::make_scalar_function;
38
39use std::any::Any;
40use std::sync::Arc;
41
// Create static instances of ScalarUDFs for each function
//
// NOTE(review): `make_udf_expr_and_func!` is defined outside this file. Each
// invocation below appears to pass, in order: the UDF struct, the name of the
// expression-builder function, that function's argument names, its doc string,
// and the name of the function returning the shared UDF instance — confirm
// against the macro definition.

// `array_replace`: replace only the first matching element.
make_udf_expr_and_func!(ArrayReplace,
    array_replace,
    array from to,
    "replaces the first occurrence of the specified element with another specified element.",
    array_replace_udf
);
// `array_replace_n`: replace at most `max` matching elements.
make_udf_expr_and_func!(ArrayReplaceN,
    array_replace_n,
    array from to max,
    "replaces the first `max` occurrences of the specified element with another specified element.",
    array_replace_n_udf
);
// `array_replace_all`: replace every matching element.
make_udf_expr_and_func!(ArrayReplaceAll,
    array_replace_all,
    array from to,
    "replaces all occurrences of the specified element with another specified element.",
    array_replace_all_udf
);
61
/// Scalar UDF implementing `array_replace` (alias `list_replace`): replaces
/// the first occurrence of an element in each list with another element.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first occurrence of the specified element with another specified element.",
    syntax_example = "array_replace(array, from, to)",
    sql_example = r#"```sql
> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
+--------------------------------------------------------+
| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+--------------------------------------------------------+
| [1, 5, 2, 3, 2, 1, 4]                                  |
+--------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayReplace {
    /// Accepted argument types and volatility, returned by `signature()`.
    signature: Signature,
    /// Alternative names for the function, returned by `aliases()`.
    aliases: Vec<String>,
}
86
87impl Default for ArrayReplace {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl ArrayReplace {
94    pub fn new() -> Self {
95        Self {
96            signature: Signature {
97                type_signature: TypeSignature::ArraySignature(
98                    ArrayFunctionSignature::Array {
99                        arguments: vec![
100                            ArrayFunctionArgument::Array,
101                            ArrayFunctionArgument::Element,
102                            ArrayFunctionArgument::Element,
103                        ],
104                        array_coercion: Some(ListCoercion::FixedSizedListToList),
105                    },
106                ),
107                volatility: Volatility::Immutable,
108                parameter_names: None,
109            },
110            aliases: vec![String::from("list_replace")],
111        }
112    }
113}
114
115impl ScalarUDFImpl for ArrayReplace {
116    fn as_any(&self) -> &dyn Any {
117        self
118    }
119
120    fn name(&self) -> &str {
121        "array_replace"
122    }
123
124    fn signature(&self) -> &Signature {
125        &self.signature
126    }
127
128    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
129        Ok(args[0].clone())
130    }
131
132    fn invoke_with_args(
133        &self,
134        args: datafusion_expr::ScalarFunctionArgs,
135    ) -> Result<ColumnarValue> {
136        make_scalar_function(array_replace_inner)(&args.args)
137    }
138
139    fn aliases(&self) -> &[String] {
140        &self.aliases
141    }
142
143    fn documentation(&self) -> Option<&Documentation> {
144        self.doc()
145    }
146}
147
/// Scalar UDF implementing `array_replace_n` (alias `list_replace_n`):
/// replaces the first `max` occurrences of an element in each list with
/// another element.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first `max` occurrences of the specified element with another specified element.",
    syntax_example = "array_replace_n(array, from, to, max)",
    sql_example = r#"```sql
> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
+-------------------------------------------------------------------+
| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
+-------------------------------------------------------------------+
| [1, 5, 5, 3, 2, 1, 4]                                             |
+-------------------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element."),
    argument(name = "max", description = "Number of first occurrences to replace.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayReplaceN {
    /// Accepted argument types and volatility, returned by `signature()`.
    signature: Signature,
    /// Alternative names for the function, returned by `aliases()`.
    aliases: Vec<String>,
}
173
174impl ArrayReplaceN {
175    pub fn new() -> Self {
176        Self {
177            signature: Signature {
178                type_signature: TypeSignature::ArraySignature(
179                    ArrayFunctionSignature::Array {
180                        arguments: vec![
181                            ArrayFunctionArgument::Array,
182                            ArrayFunctionArgument::Element,
183                            ArrayFunctionArgument::Element,
184                            ArrayFunctionArgument::Index,
185                        ],
186                        array_coercion: Some(ListCoercion::FixedSizedListToList),
187                    },
188                ),
189                volatility: Volatility::Immutable,
190                parameter_names: None,
191            },
192            aliases: vec![String::from("list_replace_n")],
193        }
194    }
195}
196
197impl ScalarUDFImpl for ArrayReplaceN {
198    fn as_any(&self) -> &dyn Any {
199        self
200    }
201
202    fn name(&self) -> &str {
203        "array_replace_n"
204    }
205
206    fn signature(&self) -> &Signature {
207        &self.signature
208    }
209
210    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
211        Ok(args[0].clone())
212    }
213
214    fn invoke_with_args(
215        &self,
216        args: datafusion_expr::ScalarFunctionArgs,
217    ) -> Result<ColumnarValue> {
218        make_scalar_function(array_replace_n_inner)(&args.args)
219    }
220
221    fn aliases(&self) -> &[String] {
222        &self.aliases
223    }
224
225    fn documentation(&self) -> Option<&Documentation> {
226        self.doc()
227    }
228}
229
/// Scalar UDF implementing `array_replace_all` (alias `list_replace_all`):
/// replaces every occurrence of an element in each list with another element.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces all occurrences of the specified element with another specified element.",
    syntax_example = "array_replace_all(array, from, to)",
    sql_example = r#"```sql
> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
+------------------------------------------------------------+
| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+------------------------------------------------------------+
| [1, 5, 5, 3, 5, 1, 4]                                      |
+------------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayReplaceAll {
    /// Accepted argument types and volatility, returned by `signature()`.
    signature: Signature,
    /// Alternative names for the function, returned by `aliases()`.
    aliases: Vec<String>,
}
254
255impl ArrayReplaceAll {
256    pub fn new() -> Self {
257        Self {
258            signature: Signature {
259                type_signature: TypeSignature::ArraySignature(
260                    ArrayFunctionSignature::Array {
261                        arguments: vec![
262                            ArrayFunctionArgument::Array,
263                            ArrayFunctionArgument::Element,
264                            ArrayFunctionArgument::Element,
265                        ],
266                        array_coercion: Some(ListCoercion::FixedSizedListToList),
267                    },
268                ),
269                volatility: Volatility::Immutable,
270                parameter_names: None,
271            },
272            aliases: vec![String::from("list_replace_all")],
273        }
274    }
275}
276
277impl ScalarUDFImpl for ArrayReplaceAll {
278    fn as_any(&self) -> &dyn Any {
279        self
280    }
281
282    fn name(&self) -> &str {
283        "array_replace_all"
284    }
285
286    fn signature(&self) -> &Signature {
287        &self.signature
288    }
289
290    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
291        Ok(args[0].clone())
292    }
293
294    fn invoke_with_args(
295        &self,
296        args: datafusion_expr::ScalarFunctionArgs,
297    ) -> Result<ColumnarValue> {
298        make_scalar_function(array_replace_all_inner)(&args.args)
299    }
300
301    fn aliases(&self) -> &[String] {
302        &self.aliases
303    }
304
305    fn documentation(&self) -> Option<&Documentation> {
306        self.doc()
307    }
308}
309
310/// For each element of `list_array[i]`, replaces up to `arr_n[i]`  occurrences
311/// of `from_array[i]`, `to_array[i]`.
312///
313/// The type of each **element** in `list_array` must be the same as the type of
314/// `from_array` and `to_array`. This function also handles nested arrays
315/// (\[`ListArray`\] of \[`ListArray`\]s)
316///
317/// For example, when called to replace a list array (where each element is a
318/// list of int32s, the second and third argument are int32 arrays, and the
319/// fourth argument is the number of occurrences to replace
320///
321/// ```text
322/// general_replace(
323///   [1, 2, 3, 2], 2, 10, 1    ==> [1, 10, 3, 2]   (only the first 2 is replaced)
324///   [4, 5, 6, 5], 5, 20, 2    ==> [4, 20, 6, 20]  (both 5s are replaced)
325/// )
326/// ```
327fn general_replace<O: OffsetSizeTrait>(
328    list_array: &GenericListArray<O>,
329    from_array: &ArrayRef,
330    to_array: &ArrayRef,
331    arr_n: Vec<i64>,
332) -> Result<ArrayRef> {
333    // Build up the offsets for the final output array
334    let mut offsets: Vec<O> = vec![O::usize_as(0)];
335    let values = list_array.values();
336    let original_data = values.to_data();
337    let to_data = to_array.to_data();
338    let capacity = Capacities::Array(original_data.len());
339
340    // First array is the original array, second array is the element to replace with.
341    let mut mutable = MutableArrayData::with_capacities(
342        vec![&original_data, &to_data],
343        false,
344        capacity,
345    );
346
347    let mut valid = NullBufferBuilder::new(list_array.len());
348
349    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
350        if list_array.is_null(row_index) {
351            offsets.push(offsets[row_index]);
352            valid.append_null();
353            continue;
354        }
355
356        let start = offset_window[0];
357        let end = offset_window[1];
358
359        let list_array_row = list_array.value(row_index);
360
361        // Compute all positions in list_row_array (that is itself an
362        // array) that are equal to `from_array_row`
363        let eq_array =
364            compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
365
366        let original_idx = O::usize_as(0);
367        let replace_idx = O::usize_as(1);
368        let n = arr_n[row_index];
369        let mut counter = 0;
370
371        // All elements are false, no need to replace, just copy original data
372        if eq_array.false_count() == eq_array.len() {
373            mutable.extend(
374                original_idx.to_usize().unwrap(),
375                start.to_usize().unwrap(),
376                end.to_usize().unwrap(),
377            );
378            offsets.push(offsets[row_index] + (end - start));
379            valid.append_non_null();
380            continue;
381        }
382
383        for (i, to_replace) in eq_array.iter().enumerate() {
384            let i = O::usize_as(i);
385            if let Some(true) = to_replace {
386                mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
387                counter += 1;
388                if counter == n {
389                    // copy original data for any matches past n
390                    mutable.extend(
391                        original_idx.to_usize().unwrap(),
392                        (start + i).to_usize().unwrap() + 1,
393                        end.to_usize().unwrap(),
394                    );
395                    break;
396                }
397            } else {
398                // copy original data for false / null matches
399                mutable.extend(
400                    original_idx.to_usize().unwrap(),
401                    (start + i).to_usize().unwrap(),
402                    (start + i).to_usize().unwrap() + 1,
403                );
404            }
405        }
406
407        offsets.push(offsets[row_index] + (end - start));
408        valid.append_non_null();
409    }
410
411    let data = mutable.freeze();
412
413    Ok(Arc::new(GenericListArray::<O>::try_new(
414        Arc::new(Field::new_list_field(list_array.value_type(), true)),
415        OffsetBuffer::<O>::new(offsets.into()),
416        arrow::array::make_array(data),
417        valid.finish(),
418    )?))
419}
420
421pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
422    let [array, from, to] = take_function_args("array_replace", args)?;
423
424    // replace at most one occurrence for each element
425    let arr_n = vec![1; array.len()];
426    match array.data_type() {
427        DataType::List(_) => {
428            let list_array = array.as_list::<i32>();
429            general_replace::<i32>(list_array, from, to, arr_n)
430        }
431        DataType::LargeList(_) => {
432            let list_array = array.as_list::<i64>();
433            general_replace::<i64>(list_array, from, to, arr_n)
434        }
435        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
436        array_type => exec_err!("array_replace does not support type '{array_type}'."),
437    }
438}
439
440pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
441    let [array, from, to, max] = take_function_args("array_replace_n", args)?;
442
443    // replace the specified number of occurrences
444    let arr_n = as_int64_array(max)?.values().to_vec();
445    match array.data_type() {
446        DataType::List(_) => {
447            let list_array = array.as_list::<i32>();
448            general_replace::<i32>(list_array, from, to, arr_n)
449        }
450        DataType::LargeList(_) => {
451            let list_array = array.as_list::<i64>();
452            general_replace::<i64>(list_array, from, to, arr_n)
453        }
454        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
455        array_type => {
456            exec_err!("array_replace_n does not support type '{array_type}'.")
457        }
458    }
459}
460
461pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
462    let [array, from, to] = take_function_args("array_replace_all", args)?;
463
464    // replace all occurrences (up to "i64::MAX")
465    let arr_n = vec![i64::MAX; array.len()];
466    match array.data_type() {
467        DataType::List(_) => {
468            let list_array = array.as_list::<i32>();
469            general_replace::<i32>(list_array, from, to, arr_n)
470        }
471        DataType::LargeList(_) => {
472            let list_array = array.as_list::<i64>();
473            general_replace::<i64>(list_array, from, to, arr_n)
474        }
475        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
476        array_type => {
477            exec_err!("array_replace_all does not support type '{array_type}'.")
478        }
479    }
480}