datafusion_functions_nested/
replace.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions.
19
20use arrow::array::{
21    new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray,
22    MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
23};
24use arrow::datatypes::{DataType, Field};
25
26use arrow::buffer::OffsetBuffer;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{exec_err, utils::take_function_args, Result};
30use datafusion_expr::{
31    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32    ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::utils::compare_element_to_list;
37use crate::utils::make_scalar_function;
38
39use std::any::Any;
40use std::sync::Arc;
41
// Create static instances of ScalarUDFs for each function.
//
// Each invocation below lists, in order: the `ScalarUDFImpl` type, the name of
// the generated expression function, that function's argument names, a short
// description, and the name of the UDF accessor. (Argument roles are inferred
// from how they are used here; confirm against the `make_udf_expr_and_func!`
// macro definition.)
make_udf_expr_and_func!(ArrayReplace,
    array_replace,
    array from to,
    "replaces the first occurrence of the specified element with another specified element.",
    array_replace_udf
);
make_udf_expr_and_func!(ArrayReplaceN,
    array_replace_n,
    array from to max,
    "replaces the first `max` occurrences of the specified element with another specified element.",
    array_replace_n_udf
);
make_udf_expr_and_func!(ArrayReplaceAll,
    array_replace_all,
    array from to,
    "replaces all occurrences of the specified element with another specified element.",
    array_replace_all_udf
);
61
/// Scalar UDF implementing `array_replace` (also registered as `list_replace`).
///
/// Replaces the first occurrence of `from` with `to` in each list; the work is
/// delegated to `array_replace_inner`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first occurrence of the specified element with another specified element.",
    syntax_example = "array_replace(array, from, to)",
    sql_example = r#"```sql
> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
+--------------------------------------------------------+
| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+--------------------------------------------------------+
| [1, 5, 2, 3, 2, 1, 4]                                  |
+--------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug)]
pub struct ArrayReplace {
    /// Accepted argument types and volatility, built in `ArrayReplace::new`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
86
87impl Default for ArrayReplace {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl ArrayReplace {
94    pub fn new() -> Self {
95        Self {
96            signature: Signature {
97                type_signature: TypeSignature::ArraySignature(
98                    ArrayFunctionSignature::Array {
99                        arguments: vec![
100                            ArrayFunctionArgument::Array,
101                            ArrayFunctionArgument::Element,
102                            ArrayFunctionArgument::Element,
103                        ],
104                        array_coercion: Some(ListCoercion::FixedSizedListToList),
105                    },
106                ),
107                volatility: Volatility::Immutable,
108            },
109            aliases: vec![String::from("list_replace")],
110        }
111    }
112}
113
114impl ScalarUDFImpl for ArrayReplace {
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118
119    fn name(&self) -> &str {
120        "array_replace"
121    }
122
123    fn signature(&self) -> &Signature {
124        &self.signature
125    }
126
127    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
128        Ok(args[0].clone())
129    }
130
131    fn invoke_with_args(
132        &self,
133        args: datafusion_expr::ScalarFunctionArgs,
134    ) -> Result<ColumnarValue> {
135        make_scalar_function(array_replace_inner)(&args.args)
136    }
137
138    fn aliases(&self) -> &[String] {
139        &self.aliases
140    }
141
142    fn documentation(&self) -> Option<&Documentation> {
143        self.doc()
144    }
145}
146
/// Scalar UDF implementing `array_replace_n` (also registered as
/// `list_replace_n`).
///
/// Replaces the first `max` occurrences of `from` with `to` in each list; the
/// work is delegated to `array_replace_n_inner`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces the first `max` occurrences of the specified element with another specified element.",
    syntax_example = "array_replace_n(array, from, to, max)",
    sql_example = r#"```sql
> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
+-------------------------------------------------------------------+
| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
+-------------------------------------------------------------------+
| [1, 5, 5, 3, 2, 1, 4]                                             |
+-------------------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element."),
    argument(name = "max", description = "Number of first occurrences to replace.")
)]
#[derive(Debug)]
pub(super) struct ArrayReplaceN {
    /// Accepted argument types and volatility, built in `ArrayReplaceN::new`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
172
173impl ArrayReplaceN {
174    pub fn new() -> Self {
175        Self {
176            signature: Signature {
177                type_signature: TypeSignature::ArraySignature(
178                    ArrayFunctionSignature::Array {
179                        arguments: vec![
180                            ArrayFunctionArgument::Array,
181                            ArrayFunctionArgument::Element,
182                            ArrayFunctionArgument::Element,
183                            ArrayFunctionArgument::Index,
184                        ],
185                        array_coercion: Some(ListCoercion::FixedSizedListToList),
186                    },
187                ),
188                volatility: Volatility::Immutable,
189            },
190            aliases: vec![String::from("list_replace_n")],
191        }
192    }
193}
194
195impl ScalarUDFImpl for ArrayReplaceN {
196    fn as_any(&self) -> &dyn Any {
197        self
198    }
199
200    fn name(&self) -> &str {
201        "array_replace_n"
202    }
203
204    fn signature(&self) -> &Signature {
205        &self.signature
206    }
207
208    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
209        Ok(args[0].clone())
210    }
211
212    fn invoke_with_args(
213        &self,
214        args: datafusion_expr::ScalarFunctionArgs,
215    ) -> Result<ColumnarValue> {
216        make_scalar_function(array_replace_n_inner)(&args.args)
217    }
218
219    fn aliases(&self) -> &[String] {
220        &self.aliases
221    }
222
223    fn documentation(&self) -> Option<&Documentation> {
224        self.doc()
225    }
226}
227
/// Scalar UDF implementing `array_replace_all` (also registered as
/// `list_replace_all`).
///
/// Replaces every occurrence of `from` with `to` in each list; the work is
/// delegated to `array_replace_all_inner`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Replaces all occurrences of the specified element with another specified element.",
    syntax_example = "array_replace_all(array, from, to)",
    sql_example = r#"```sql
> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
+------------------------------------------------------------+
| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
+------------------------------------------------------------+
| [1, 5, 5, 3, 5, 1, 4]                                      |
+------------------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "from", description = "Initial element."),
    argument(name = "to", description = "Final element.")
)]
#[derive(Debug)]
pub(super) struct ArrayReplaceAll {
    /// Accepted argument types and volatility, built in `ArrayReplaceAll::new`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
252
253impl ArrayReplaceAll {
254    pub fn new() -> Self {
255        Self {
256            signature: Signature {
257                type_signature: TypeSignature::ArraySignature(
258                    ArrayFunctionSignature::Array {
259                        arguments: vec![
260                            ArrayFunctionArgument::Array,
261                            ArrayFunctionArgument::Element,
262                            ArrayFunctionArgument::Element,
263                        ],
264                        array_coercion: Some(ListCoercion::FixedSizedListToList),
265                    },
266                ),
267                volatility: Volatility::Immutable,
268            },
269            aliases: vec![String::from("list_replace_all")],
270        }
271    }
272}
273
274impl ScalarUDFImpl for ArrayReplaceAll {
275    fn as_any(&self) -> &dyn Any {
276        self
277    }
278
279    fn name(&self) -> &str {
280        "array_replace_all"
281    }
282
283    fn signature(&self) -> &Signature {
284        &self.signature
285    }
286
287    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
288        Ok(args[0].clone())
289    }
290
291    fn invoke_with_args(
292        &self,
293        args: datafusion_expr::ScalarFunctionArgs,
294    ) -> Result<ColumnarValue> {
295        make_scalar_function(array_replace_all_inner)(&args.args)
296    }
297
298    fn aliases(&self) -> &[String] {
299        &self.aliases
300    }
301
302    fn documentation(&self) -> Option<&Documentation> {
303        self.doc()
304    }
305}
306
307/// For each element of `list_array[i]`, replaces up to `arr_n[i]`  occurrences
308/// of `from_array[i]`, `to_array[i]`.
309///
310/// The type of each **element** in `list_array` must be the same as the type of
311/// `from_array` and `to_array`. This function also handles nested arrays
312/// (\[`ListArray`\] of \[`ListArray`\]s)
313///
314/// For example, when called to replace a list array (where each element is a
315/// list of int32s, the second and third argument are int32 arrays, and the
316/// fourth argument is the number of occurrences to replace
317///
318/// ```text
319/// general_replace(
320///   [1, 2, 3, 2], 2, 10, 1    ==> [1, 10, 3, 2]   (only the first 2 is replaced)
321///   [4, 5, 6, 5], 5, 20, 2    ==> [4, 20, 6, 20]  (both 5s are replaced)
322/// )
323/// ```
324fn general_replace<O: OffsetSizeTrait>(
325    list_array: &GenericListArray<O>,
326    from_array: &ArrayRef,
327    to_array: &ArrayRef,
328    arr_n: Vec<i64>,
329) -> Result<ArrayRef> {
330    // Build up the offsets for the final output array
331    let mut offsets: Vec<O> = vec![O::usize_as(0)];
332    let values = list_array.values();
333    let original_data = values.to_data();
334    let to_data = to_array.to_data();
335    let capacity = Capacities::Array(original_data.len());
336
337    // First array is the original array, second array is the element to replace with.
338    let mut mutable = MutableArrayData::with_capacities(
339        vec![&original_data, &to_data],
340        false,
341        capacity,
342    );
343
344    let mut valid = NullBufferBuilder::new(list_array.len());
345
346    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
347        if list_array.is_null(row_index) {
348            offsets.push(offsets[row_index]);
349            valid.append_null();
350            continue;
351        }
352
353        let start = offset_window[0];
354        let end = offset_window[1];
355
356        let list_array_row = list_array.value(row_index);
357
358        // Compute all positions in list_row_array (that is itself an
359        // array) that are equal to `from_array_row`
360        let eq_array =
361            compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
362
363        let original_idx = O::usize_as(0);
364        let replace_idx = O::usize_as(1);
365        let n = arr_n[row_index];
366        let mut counter = 0;
367
368        // All elements are false, no need to replace, just copy original data
369        if eq_array.false_count() == eq_array.len() {
370            mutable.extend(
371                original_idx.to_usize().unwrap(),
372                start.to_usize().unwrap(),
373                end.to_usize().unwrap(),
374            );
375            offsets.push(offsets[row_index] + (end - start));
376            valid.append_non_null();
377            continue;
378        }
379
380        for (i, to_replace) in eq_array.iter().enumerate() {
381            let i = O::usize_as(i);
382            if let Some(true) = to_replace {
383                mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
384                counter += 1;
385                if counter == n {
386                    // copy original data for any matches past n
387                    mutable.extend(
388                        original_idx.to_usize().unwrap(),
389                        (start + i).to_usize().unwrap() + 1,
390                        end.to_usize().unwrap(),
391                    );
392                    break;
393                }
394            } else {
395                // copy original data for false / null matches
396                mutable.extend(
397                    original_idx.to_usize().unwrap(),
398                    (start + i).to_usize().unwrap(),
399                    (start + i).to_usize().unwrap() + 1,
400                );
401            }
402        }
403
404        offsets.push(offsets[row_index] + (end - start));
405        valid.append_non_null();
406    }
407
408    let data = mutable.freeze();
409
410    Ok(Arc::new(GenericListArray::<O>::try_new(
411        Arc::new(Field::new_list_field(list_array.value_type(), true)),
412        OffsetBuffer::<O>::new(offsets.into()),
413        arrow::array::make_array(data),
414        valid.finish(),
415    )?))
416}
417
418pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
419    let [array, from, to] = take_function_args("array_replace", args)?;
420
421    // replace at most one occurrence for each element
422    let arr_n = vec![1; array.len()];
423    match array.data_type() {
424        DataType::List(_) => {
425            let list_array = array.as_list::<i32>();
426            general_replace::<i32>(list_array, from, to, arr_n)
427        }
428        DataType::LargeList(_) => {
429            let list_array = array.as_list::<i64>();
430            general_replace::<i64>(list_array, from, to, arr_n)
431        }
432        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
433        array_type => exec_err!("array_replace does not support type '{array_type:?}'."),
434    }
435}
436
437pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
438    let [array, from, to, max] = take_function_args("array_replace_n", args)?;
439
440    // replace the specified number of occurrences
441    let arr_n = as_int64_array(max)?.values().to_vec();
442    match array.data_type() {
443        DataType::List(_) => {
444            let list_array = array.as_list::<i32>();
445            general_replace::<i32>(list_array, from, to, arr_n)
446        }
447        DataType::LargeList(_) => {
448            let list_array = array.as_list::<i64>();
449            general_replace::<i64>(list_array, from, to, arr_n)
450        }
451        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
452        array_type => {
453            exec_err!("array_replace_n does not support type '{array_type:?}'.")
454        }
455    }
456}
457
458pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
459    let [array, from, to] = take_function_args("array_replace_all", args)?;
460
461    // replace all occurrences (up to "i64::MAX")
462    let arr_n = vec![i64::MAX; array.len()];
463    match array.data_type() {
464        DataType::List(_) => {
465            let list_array = array.as_list::<i32>();
466            general_replace::<i32>(list_array, from, to, arr_n)
467        }
468        DataType::LargeList(_) => {
469            let list_array = array.as_list::<i64>();
470            general_replace::<i64>(list_array, from, to, arr_n)
471        }
472        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
473        array_type => {
474            exec_err!("array_replace_all does not support type '{array_type:?}'.")
475        }
476    }
477}