Skip to main content

datafusion_functions_nested/
replace.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions.
19
20use arrow::array::{
21    Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
22    NullBufferBuilder, OffsetSizeTrait, new_null_array,
23};
24use arrow::datatypes::{DataType, Field};
25
26use arrow::buffer::OffsetBuffer;
27use datafusion_common::cast::as_int64_array;
28use datafusion_common::utils::ListCoercion;
29use datafusion_common::{Result, exec_err, utils::take_function_args};
30use datafusion_expr::{
31    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
33};
34use datafusion_macros::user_doc;
35
36use crate::utils::compare_element_to_list;
37use crate::utils::make_scalar_function;
38
39use std::sync::Arc;
40
41// Create static instances of ScalarUDFs for each function
42make_udf_expr_and_func!(ArrayReplace,
43    array_replace,
44    array from to,
45    "replaces the first occurrence of the specified element with another specified element.",
46    array_replace_udf
47);
48make_udf_expr_and_func!(ArrayReplaceN,
49    array_replace_n,
50    array from to max,
51    "replaces the first `max` occurrences of the specified element with another specified element.",
52    array_replace_n_udf
53);
54make_udf_expr_and_func!(ArrayReplaceAll,
55    array_replace_all,
56    array from to,
57    "replaces all occurrences of the specified element with another specified element.",
58    array_replace_all_udf
59);
60
61#[user_doc(
62    doc_section(label = "Array Functions"),
63    description = "Replaces the first occurrence of the specified element with another specified element.",
64    syntax_example = "array_replace(array, from, to)",
65    sql_example = r#"```sql
66> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5);
67+--------------------------------------------------------+
68| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
69+--------------------------------------------------------+
70| [1, 5, 2, 3, 2, 1, 4]                                  |
71+--------------------------------------------------------+
72```"#,
73    argument(
74        name = "array",
75        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
76    ),
77    argument(name = "from", description = "Initial element."),
78    argument(name = "to", description = "Final element.")
79)]
80#[derive(Debug, PartialEq, Eq, Hash)]
81pub struct ArrayReplace {
82    signature: Signature,
83    aliases: Vec<String>,
84}
85
86impl Default for ArrayReplace {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl ArrayReplace {
93    pub fn new() -> Self {
94        Self {
95            signature: Signature {
96                type_signature: TypeSignature::ArraySignature(
97                    ArrayFunctionSignature::Array {
98                        arguments: vec![
99                            ArrayFunctionArgument::Array,
100                            ArrayFunctionArgument::Element,
101                            ArrayFunctionArgument::Element,
102                        ],
103                        array_coercion: Some(ListCoercion::FixedSizedListToList),
104                    },
105                ),
106                volatility: Volatility::Immutable,
107                parameter_names: None,
108            },
109            aliases: vec![String::from("list_replace")],
110        }
111    }
112}
113
114impl ScalarUDFImpl for ArrayReplace {
115    fn name(&self) -> &str {
116        "array_replace"
117    }
118
119    fn signature(&self) -> &Signature {
120        &self.signature
121    }
122
123    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
124        Ok(args[0].clone())
125    }
126
127    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
128        make_scalar_function(array_replace_inner)(&args.args)
129    }
130
131    fn aliases(&self) -> &[String] {
132        &self.aliases
133    }
134
135    fn documentation(&self) -> Option<&Documentation> {
136        self.doc()
137    }
138}
139
140#[user_doc(
141    doc_section(label = "Array Functions"),
142    description = "Replaces the first `max` occurrences of the specified element with another specified element.",
143    syntax_example = "array_replace_n(array, from, to, max)",
144    sql_example = r#"```sql
145> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2);
146+-------------------------------------------------------------------+
147| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) |
148+-------------------------------------------------------------------+
149| [1, 5, 5, 3, 2, 1, 4]                                             |
150+-------------------------------------------------------------------+
151```"#,
152    argument(
153        name = "array",
154        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
155    ),
156    argument(name = "from", description = "Initial element."),
157    argument(name = "to", description = "Final element."),
158    argument(name = "max", description = "Number of first occurrences to replace.")
159)]
160#[derive(Debug, PartialEq, Eq, Hash)]
161pub(super) struct ArrayReplaceN {
162    signature: Signature,
163    aliases: Vec<String>,
164}
165
166impl ArrayReplaceN {
167    pub fn new() -> Self {
168        Self {
169            signature: Signature {
170                type_signature: TypeSignature::ArraySignature(
171                    ArrayFunctionSignature::Array {
172                        arguments: vec![
173                            ArrayFunctionArgument::Array,
174                            ArrayFunctionArgument::Element,
175                            ArrayFunctionArgument::Element,
176                            ArrayFunctionArgument::Index,
177                        ],
178                        array_coercion: Some(ListCoercion::FixedSizedListToList),
179                    },
180                ),
181                volatility: Volatility::Immutable,
182                parameter_names: None,
183            },
184            aliases: vec![String::from("list_replace_n")],
185        }
186    }
187}
188
189impl ScalarUDFImpl for ArrayReplaceN {
190    fn name(&self) -> &str {
191        "array_replace_n"
192    }
193
194    fn signature(&self) -> &Signature {
195        &self.signature
196    }
197
198    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
199        Ok(args[0].clone())
200    }
201
202    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
203        make_scalar_function(array_replace_n_inner)(&args.args)
204    }
205
206    fn aliases(&self) -> &[String] {
207        &self.aliases
208    }
209
210    fn documentation(&self) -> Option<&Documentation> {
211        self.doc()
212    }
213}
214
215#[user_doc(
216    doc_section(label = "Array Functions"),
217    description = "Replaces all occurrences of the specified element with another specified element.",
218    syntax_example = "array_replace_all(array, from, to)",
219    sql_example = r#"```sql
220> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5);
221+------------------------------------------------------------+
222| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) |
223+------------------------------------------------------------+
224| [1, 5, 5, 3, 5, 1, 4]                                      |
225+------------------------------------------------------------+
226```"#,
227    argument(
228        name = "array",
229        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
230    ),
231    argument(name = "from", description = "Initial element."),
232    argument(name = "to", description = "Final element.")
233)]
234#[derive(Debug, PartialEq, Eq, Hash)]
235pub(super) struct ArrayReplaceAll {
236    signature: Signature,
237    aliases: Vec<String>,
238}
239
240impl ArrayReplaceAll {
241    pub fn new() -> Self {
242        Self {
243            signature: Signature {
244                type_signature: TypeSignature::ArraySignature(
245                    ArrayFunctionSignature::Array {
246                        arguments: vec![
247                            ArrayFunctionArgument::Array,
248                            ArrayFunctionArgument::Element,
249                            ArrayFunctionArgument::Element,
250                        ],
251                        array_coercion: Some(ListCoercion::FixedSizedListToList),
252                    },
253                ),
254                volatility: Volatility::Immutable,
255                parameter_names: None,
256            },
257            aliases: vec![String::from("list_replace_all")],
258        }
259    }
260}
261
262impl ScalarUDFImpl for ArrayReplaceAll {
263    fn name(&self) -> &str {
264        "array_replace_all"
265    }
266
267    fn signature(&self) -> &Signature {
268        &self.signature
269    }
270
271    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
272        Ok(args[0].clone())
273    }
274
275    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
276        make_scalar_function(array_replace_all_inner)(&args.args)
277    }
278
279    fn aliases(&self) -> &[String] {
280        &self.aliases
281    }
282
283    fn documentation(&self) -> Option<&Documentation> {
284        self.doc()
285    }
286}
287
288/// For each element of `list_array[i]`, replaces up to `arr_n[i]`  occurrences
289/// of `from_array[i]`, `to_array[i]`.
290///
291/// The type of each **element** in `list_array` must be the same as the type of
292/// `from_array` and `to_array`. This function also handles nested arrays
293/// (\[`ListArray`\] of \[`ListArray`\]s)
294///
295/// For example, when called to replace a list array (where each element is a
296/// list of int32s, the second and third argument are int32 arrays, and the
297/// fourth argument is the number of occurrences to replace
298///
299/// ```text
300/// general_replace(
301///   [1, 2, 3, 2], 2, 10, 1    ==> [1, 10, 3, 2]   (only the first 2 is replaced)
302///   [4, 5, 6, 5], 5, 20, 2    ==> [4, 20, 6, 20]  (both 5s are replaced)
303/// )
304/// ```
305fn general_replace<O: OffsetSizeTrait>(
306    list_array: &GenericListArray<O>,
307    from_array: &ArrayRef,
308    to_array: &ArrayRef,
309    arr_n: &[i64],
310) -> Result<ArrayRef> {
311    // Build up the offsets for the final output array
312    let mut offsets: Vec<O> = vec![O::usize_as(0)];
313    let values = list_array.values();
314    let original_data = values.to_data();
315    let to_data = to_array.to_data();
316    let capacity = Capacities::Array(original_data.len());
317
318    // First array is the original array, second array is the element to replace with.
319    let mut mutable = MutableArrayData::with_capacities(
320        vec![&original_data, &to_data],
321        false,
322        capacity,
323    );
324
325    let mut valid = NullBufferBuilder::new(list_array.len());
326
327    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
328        if list_array.is_null(row_index) {
329            offsets.push(offsets[row_index]);
330            valid.append_null();
331            continue;
332        }
333
334        let start = offset_window[0];
335        let end = offset_window[1];
336
337        let list_array_row = list_array.value(row_index);
338
339        // Compute all positions in list_row_array (that is itself an
340        // array) that are equal to `from_array_row`
341        let eq_array =
342            compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
343
344        let original_idx = O::usize_as(0);
345        let replace_idx = O::usize_as(1);
346        let n = arr_n[row_index];
347        let mut counter = 0;
348
349        // All elements are false, no need to replace, just copy original data
350        if n <= 0 || !eq_array.has_true() {
351            mutable.extend(
352                original_idx.to_usize().unwrap(),
353                start.to_usize().unwrap(),
354                end.to_usize().unwrap(),
355            );
356            offsets.push(offsets[row_index] + (end - start));
357            valid.append_non_null();
358            continue;
359        }
360
361        let mut pending_retain: Option<O> = None;
362        for (i, to_replace) in eq_array.iter().enumerate() {
363            let i = O::usize_as(i);
364            if to_replace == Some(true) && counter < n {
365                // Flush any pending retain run before emitting the replacement.
366                if let Some(rs) = pending_retain.take() {
367                    mutable.extend(
368                        original_idx.to_usize().unwrap(),
369                        (start + rs).to_usize().unwrap(),
370                        (start + i).to_usize().unwrap(),
371                    );
372                }
373                mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
374                counter += 1;
375                if counter == n {
376                    // copy original data for any matches past n
377                    mutable.extend(
378                        original_idx.to_usize().unwrap(),
379                        (start + i).to_usize().unwrap() + 1,
380                        end.to_usize().unwrap(),
381                    );
382                    break;
383                }
384            } else if pending_retain.is_none() {
385                pending_retain = Some(i);
386            }
387        }
388
389        // Flush trailing retain run when we exited the loop without ever
390        // hitting `counter == n` (i.e. fewer than `n` matches in this row).
391        if counter < n
392            && let Some(rs) = pending_retain
393        {
394            mutable.extend(
395                original_idx.to_usize().unwrap(),
396                (start + rs).to_usize().unwrap(),
397                end.to_usize().unwrap(),
398            );
399        }
400
401        offsets.push(offsets[row_index] + (end - start));
402        valid.append_non_null();
403    }
404
405    let data = mutable.freeze();
406
407    Ok(Arc::new(GenericListArray::<O>::try_new(
408        Arc::new(Field::new_list_field(list_array.value_type(), true)),
409        OffsetBuffer::<O>::new(offsets.into()),
410        arrow::array::make_array(data),
411        valid.finish(),
412    )?))
413}
414
415fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
416    let [array, from, to] = take_function_args("array_replace", args)?;
417
418    // replace at most one occurrence for each element
419    let arr_n = vec![1; array.len()];
420    match array.data_type() {
421        DataType::List(_) => {
422            let list_array = array.as_list::<i32>();
423            general_replace::<i32>(list_array, from, to, &arr_n)
424        }
425        DataType::LargeList(_) => {
426            let list_array = array.as_list::<i64>();
427            general_replace::<i64>(list_array, from, to, &arr_n)
428        }
429        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
430        array_type => exec_err!("array_replace does not support type '{array_type}'."),
431    }
432}
433
434fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
435    let [array, from, to, max] = take_function_args("array_replace_n", args)?;
436
437    // replace the specified number of occurrences
438    let arr_n = as_int64_array(max)?.values().to_vec();
439    match array.data_type() {
440        DataType::List(_) => {
441            let list_array = array.as_list::<i32>();
442            general_replace::<i32>(list_array, from, to, &arr_n)
443        }
444        DataType::LargeList(_) => {
445            let list_array = array.as_list::<i64>();
446            general_replace::<i64>(list_array, from, to, &arr_n)
447        }
448        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
449        array_type => {
450            exec_err!("array_replace_n does not support type '{array_type}'.")
451        }
452    }
453}
454
455fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
456    let [array, from, to] = take_function_args("array_replace_all", args)?;
457
458    // replace all occurrences (up to "i64::MAX")
459    let arr_n = vec![i64::MAX; array.len()];
460    match array.data_type() {
461        DataType::List(_) => {
462            let list_array = array.as_list::<i32>();
463            general_replace::<i32>(list_array, from, to, &arr_n)
464        }
465        DataType::LargeList(_) => {
466            let list_array = array.as_list::<i64>();
467            general_replace::<i64>(list_array, from, to, &arr_n)
468        }
469        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
470        array_type => {
471            exec_err!("array_replace_all does not support type '{array_type}'.")
472        }
473    }
474}