Skip to main content

datafusion_functions/string/
replace.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
21use arrow::buffer::NullBuffer;
22use arrow::datatypes::DataType;
23
24use crate::strings::{
25    BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringWriter,
26};
27use crate::utils::{make_scalar_function, utf8_to_str_type};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::types::logical_string;
30use datafusion_common::{Result, exec_err};
31use datafusion_expr::type_coercion::binary::{
32    binary_to_string_coercion, string_coercion,
33};
34use datafusion_expr::{
35    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36    TypeSignatureClass, Volatility,
37};
38use datafusion_macros::user_doc;
39#[user_doc(
40    doc_section(label = "String Functions"),
41    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
42    syntax_example = "replace(str, substr, replacement)",
43    sql_example = r#"```sql
44> select replace('ABabbaBA', 'ab', 'cd');
45+-------------------------------------------------+
46| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
47+-------------------------------------------------+
48| ABcdbaBA                                        |
49+-------------------------------------------------+
50```"#,
51    standard_argument(name = "str", prefix = "String"),
52    standard_argument(
53        name = "substr",
54        prefix = "Substring expression to replace in the input string. Substring"
55    ),
56    standard_argument(name = "replacement", prefix = "Replacement substring")
57)]
58#[derive(Debug, PartialEq, Eq, Hash)]
59pub struct ReplaceFunc {
60    signature: Signature,
61}
62
63impl Default for ReplaceFunc {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl ReplaceFunc {
70    pub fn new() -> Self {
71        Self {
72            signature: Signature::coercible(
73                vec![
74                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
75                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
76                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
77                ],
78                Volatility::Immutable,
79            ),
80        }
81    }
82}
83
84impl ScalarUDFImpl for ReplaceFunc {
85    fn name(&self) -> &str {
86        "replace"
87    }
88
89    fn signature(&self) -> &Signature {
90        &self.signature
91    }
92
93    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94        if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
95            .and_then(|dt| string_coercion(&dt, &arg_types[2]))
96            .or_else(|| {
97                binary_to_string_coercion(&arg_types[0], &arg_types[1])
98                    .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
99            })
100        {
101            utf8_to_str_type(&coercion_data_type, "replace")
102        } else {
103            exec_err!(
104                "Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View"
105            )
106        }
107    }
108
109    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110        let data_types = args
111            .args
112            .iter()
113            .map(|arg| arg.data_type())
114            .collect::<Vec<_>>();
115
116        if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
117            .and_then(|dt| string_coercion(&dt, &data_types[2]))
118            .or_else(|| {
119                binary_to_string_coercion(&data_types[0], &data_types[1])
120                    .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
121            })
122        {
123            let mut converted_args = Vec::with_capacity(args.args.len());
124            for arg in &args.args {
125                if arg.data_type() == coercion_type {
126                    converted_args.push(arg.clone());
127                } else {
128                    let converted = arg.cast_to(&coercion_type, None)?;
129                    converted_args.push(converted);
130                }
131            }
132
133            match coercion_type {
134                DataType::Utf8 => {
135                    make_scalar_function(replace::<i32>, vec![])(&converted_args)
136                }
137                DataType::LargeUtf8 => {
138                    make_scalar_function(replace::<i64>, vec![])(&converted_args)
139                }
140                DataType::Utf8View => {
141                    make_scalar_function(replace_view, vec![])(&converted_args)
142                }
143                other => exec_err!(
144                    "Unsupported coercion data type {other:?} for function replace"
145                ),
146            }
147        } else {
148            exec_err!(
149                "Unsupported data type {}, {:?}, {:?} for function replace.",
150                data_types[0],
151                data_types[1],
152                data_types[2]
153            )
154        }
155    }
156
157    fn documentation(&self) -> Option<&Documentation> {
158        self.doc()
159    }
160}
161
162fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
163    let string_array = as_string_view_array(&args[0])?;
164    let from_array = as_string_view_array(&args[1])?;
165    let to_array = as_string_view_array(&args[2])?;
166
167    let len = string_array.len();
168    let mut builder = GenericStringArrayBuilder::<i32>::with_capacity(len, 0);
169    let nulls = NullBuffer::union_many([
170        string_array.nulls(),
171        from_array.nulls(),
172        to_array.nulls(),
173    ]);
174
175    // Hoist the nulls.is_some() check out of the loop. LLVM does not always
176    // unswitch this loop on its own (the Utf8View body is large enough to
177    // exceed its cost-benefit threshold).
178    if let Some(nulls_ref) = nulls.as_ref() {
179        for i in 0..len {
180            if nulls_ref.is_null(i) {
181                builder.append_placeholder();
182                continue;
183            }
184            // SAFETY: union of input nulls is non-null at i, so each input is too.
185            let string = unsafe { string_array.value_unchecked(i) };
186            let from = unsafe { from_array.value_unchecked(i) };
187            let to = unsafe { to_array.value_unchecked(i) };
188            apply_replace(&mut builder, string, from, to);
189        }
190    } else {
191        for i in 0..len {
192            // SAFETY: i < len, and no input has a null buffer.
193            let string = unsafe { string_array.value_unchecked(i) };
194            let from = unsafe { from_array.value_unchecked(i) };
195            let to = unsafe { to_array.value_unchecked(i) };
196            apply_replace(&mut builder, string, from, to);
197        }
198    }
199
200    Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
201}
202
203/// Replaces all occurrences in string of substring from with substring to.
204/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
205fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
206    let string_array = as_generic_string_array::<T>(&args[0])?;
207    let from_array = as_generic_string_array::<T>(&args[1])?;
208    let to_array = as_generic_string_array::<T>(&args[2])?;
209
210    let len = string_array.len();
211    let mut builder = GenericStringArrayBuilder::<T>::with_capacity(len, 0);
212    let nulls = NullBuffer::union_many([
213        string_array.nulls(),
214        from_array.nulls(),
215        to_array.nulls(),
216    ]);
217
218    // Hoist the nulls.is_some() check out of the loop. LLVM unswitches this
219    // automatically today, but kept explicit so the no-nulls fast path is not
220    // contingent on the optimizer's cost heuristic.
221    if let Some(nulls_ref) = nulls.as_ref() {
222        for i in 0..len {
223            if nulls_ref.is_null(i) {
224                builder.append_placeholder();
225                continue;
226            }
227            // SAFETY: union of input nulls is non-null at i, so each input is too.
228            let string = unsafe { string_array.value_unchecked(i) };
229            let from = unsafe { from_array.value_unchecked(i) };
230            let to = unsafe { to_array.value_unchecked(i) };
231            apply_replace(&mut builder, string, from, to);
232        }
233    } else {
234        for i in 0..len {
235            // SAFETY: i < len, and no input has a null buffer.
236            let string = unsafe { string_array.value_unchecked(i) };
237            let from = unsafe { from_array.value_unchecked(i) };
238            let to = unsafe { to_array.value_unchecked(i) };
239            apply_replace(&mut builder, string, from, to);
240        }
241    }
242
243    Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
244}
245
246#[inline]
247fn apply_replace<B: BulkNullStringArrayBuilder>(
248    builder: &mut B,
249    string: &str,
250    from: &str,
251    to: &str,
252) {
253    // Hot path: single ASCII byte → single ASCII byte. An ASCII byte (< 0x80)
254    // cannot appear inside a multi-byte UTF-8 sequence, so any multi-byte
255    // sequences in `string` pass through unchanged and output stays valid
256    // UTF-8.
257    if let (&[from_byte], &[to_byte]) = (from.as_bytes(), to.as_bytes())
258        && from_byte.is_ascii()
259        && to_byte.is_ascii()
260    {
261        // SAFETY: see the contract above.
262        unsafe {
263            builder.append_byte_map(string.as_bytes(), |b| {
264                if b == from_byte { to_byte } else { b }
265            });
266        }
267        return;
268    }
269
270    if from.is_empty() {
271        // Empty `from`: insert `to` before each character and at both ends.
272        builder.append_with(|w| {
273            w.write_str(to);
274            for ch in string.chars() {
275                w.write_char(ch);
276                w.write_str(to);
277            }
278        });
279        return;
280    }
281
282    builder.append_with(|w| replace_into_writer(w, string, from, to));
283}
284
285#[inline]
286fn replace_into_writer<W: StringWriter>(w: &mut W, string: &str, from: &str, to: &str) {
287    let mut last_end = 0;
288    for (start, _part) in string.match_indices(from) {
289        w.write_str(&string[last_end..start]);
290        w.write_str(to);
291        last_end = start + from.len();
292    }
293    w.write_str(&string[last_end..]);
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    /// Wraps `s` as a non-null scalar `Utf8` argument.
    fn utf8(s: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(s))))
    }

    /// Wraps `s` as a non-null scalar `LargeUtf8` argument.
    fn large_utf8(s: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(s))))
    }

    /// Wraps `s` as a non-null scalar `Utf8View` argument.
    fn utf8_view(s: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(s))))
    }

    /// Exercises `replace` once per supported string type with scalar inputs.
    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 in, Utf8 out.
        test_function!(
            ReplaceFunc::new(),
            vec![utf8("aabbdqcbb"), utf8("bb"), utf8("ccc")],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8 in, LargeUtf8 out.
        test_function!(
            ReplaceFunc::new(),
            vec![large_utf8("aabbb"), large_utf8("bbb"), large_utf8("cc")],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // Utf8View in; the expected result type is Utf8.
        test_function!(
            ReplaceFunc::new(),
            vec![utf8_view("aabbbcw"), utf8_view("bb"), utf8_view("cc")],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}