datafusion_functions/string/replace.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::types::logical_string;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::type_coercion::binary::{
29    binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33    TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
/// Scalar UDF implementing SQL `replace(str, substr, replacement)`, which
/// replaces every occurrence of `substr` in `str` with `replacement`.
///
/// The user-facing documentation (description, syntax, SQL example and
/// argument descriptions) is generated from the `#[user_doc]` attribute.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA                                        |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ReplaceFunc {
    /// Accepted argument types and volatility; constructed by `ReplaceFunc::new`
    /// as three string-typed arguments with `Volatility::Immutable`.
    signature: Signature,
}
59
60impl Default for ReplaceFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl ReplaceFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::coercible(
70                vec![
71                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74                ],
75                Volatility::Immutable,
76            ),
77        }
78    }
79}
80
81impl ScalarUDFImpl for ReplaceFunc {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "replace"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95        if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
96            .and_then(|dt| string_coercion(&dt, &arg_types[2]))
97            .or_else(|| {
98                binary_to_string_coercion(&arg_types[0], &arg_types[1])
99                    .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
100            })
101        {
102            utf8_to_str_type(&coercion_data_type, "replace")
103        } else {
104            exec_err!(
105                "Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View"
106            )
107        }
108    }
109
110    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
111        let data_types = args
112            .args
113            .iter()
114            .map(|arg| arg.data_type())
115            .collect::<Vec<_>>();
116
117        if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
118            .and_then(|dt| string_coercion(&dt, &data_types[2]))
119            .or_else(|| {
120                binary_to_string_coercion(&data_types[0], &data_types[1])
121                    .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
122            })
123        {
124            let mut converted_args = Vec::with_capacity(args.args.len());
125            for arg in &args.args {
126                if arg.data_type() == coercion_type {
127                    converted_args.push(arg.clone());
128                } else {
129                    let converted = arg.cast_to(&coercion_type, None)?;
130                    converted_args.push(converted);
131                }
132            }
133
134            match coercion_type {
135                DataType::Utf8 => {
136                    make_scalar_function(replace::<i32>, vec![])(&converted_args)
137                }
138                DataType::LargeUtf8 => {
139                    make_scalar_function(replace::<i64>, vec![])(&converted_args)
140                }
141                DataType::Utf8View => {
142                    make_scalar_function(replace_view, vec![])(&converted_args)
143                }
144                other => exec_err!(
145                    "Unsupported coercion data type {other:?} for function replace"
146                ),
147            }
148        } else {
149            exec_err!(
150                "Unsupported data type {}, {:?}, {:?} for function replace.",
151                data_types[0],
152                data_types[1],
153                data_types[2]
154            )
155        }
156    }
157
158    fn documentation(&self) -> Option<&Documentation> {
159        self.doc()
160    }
161}
162
163fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
164    let string_array = as_string_view_array(&args[0])?;
165    let from_array = as_string_view_array(&args[1])?;
166    let to_array = as_string_view_array(&args[2])?;
167
168    let mut builder = GenericStringBuilder::<i32>::new();
169    let mut buffer = String::new();
170
171    for ((string, from), to) in string_array
172        .iter()
173        .zip(from_array.iter())
174        .zip(to_array.iter())
175    {
176        match (string, from, to) {
177            (Some(string), Some(from), Some(to)) => {
178                buffer.clear();
179                replace_into_string(&mut buffer, string, from, to);
180                builder.append_value(&buffer);
181            }
182            _ => builder.append_null(),
183        }
184    }
185
186    Ok(Arc::new(builder.finish()) as ArrayRef)
187}
188
189/// Replaces all occurrences in string of substring from with substring to.
190/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
191fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
192    let string_array = as_generic_string_array::<T>(&args[0])?;
193    let from_array = as_generic_string_array::<T>(&args[1])?;
194    let to_array = as_generic_string_array::<T>(&args[2])?;
195
196    let mut builder = GenericStringBuilder::<T>::new();
197    let mut buffer = String::new();
198
199    for ((string, from), to) in string_array
200        .iter()
201        .zip(from_array.iter())
202        .zip(to_array.iter())
203    {
204        match (string, from, to) {
205            (Some(string), Some(from), Some(to)) => {
206                buffer.clear();
207                replace_into_string(&mut buffer, string, from, to);
208                builder.append_value(&buffer);
209            }
210            _ => builder.append_null(),
211        }
212    }
213
214    Ok(Arc::new(builder.finish()) as ArrayRef)
215}
216
/// Appends `string` to `buffer` with every non-overlapping occurrence of `from`
/// replaced by `to`; the appended text equals `string.replace(from, to)`.
///
/// The result is *appended*: callers reusing `buffer` across rows must clear it
/// first. Writing into a caller-owned buffer lets the row loops share a single
/// allocation instead of creating a new `String` per row.
///
/// An empty `from` follows `str::replace` semantics: `to` is inserted before
/// every character and once at the end (so an empty `string` yields `to`).
#[inline]
fn replace_into_string(buffer: &mut String, string: &str, from: &str, to: &str) {
    if from.is_empty() {
        buffer.push_str(to);
        for ch in string.chars() {
            buffer.push(ch);
            buffer.push_str(to);
        }
        return;
    }

    // Returns the byte of a one-byte ASCII string, otherwise `None`.
    let single_ascii_byte = |s: &str| match s.as_bytes() {
        [b] if b.is_ascii() => Some(*b),
        _ => None,
    };

    // Fast path (the same one `str::replace` takes): ASCII byte -> ASCII byte is
    // a plain byte-wise map that the compiler can vectorize. It writes straight
    // into `buffer`, avoiding a temporary Vec per row.
    if let (Some(from_byte), Some(to_byte)) =
        (single_ascii_byte(from), single_ascii_byte(to))
    {
        // SAFETY: `from_byte` and `to_byte` are both ASCII. Every appended byte is
        // either copied unchanged from `string` (valid UTF-8) or is the ASCII
        // `to_byte` written in place of an ASCII `from_byte`. ASCII bytes never
        // occur inside multi-byte UTF-8 sequences, so no code point is split and
        // `buffer` remains valid UTF-8 when the borrow ends.
        let bytes = unsafe { buffer.as_mut_vec() };
        bytes.extend(string.bytes().map(|b| if b == from_byte { to_byte } else { b }));
        return;
    }

    buffer.reserve(string.len());
    let mut last_end = 0;
    for (start, matched) in string.match_indices(from) {
        buffer.push_str(&string[last_end..start]);
        buffer.push_str(to);
        last_end = start + matched.len();
    }
    buffer.push_str(&string[last_end..]);
}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::Array;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "aabbb"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "aabbbcw"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))),
            ],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// Covers the branches of `replace_into_string` and the coercion logic that
    /// `test_functions` does not reach.
    #[test]
    fn test_edge_cases() -> Result<()> {
        // Single ASCII byte -> single ASCII byte (byte-wise fast path).
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a.b.c")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("/")))),
            ],
            Ok(Some("a/b/c")),
            &str,
            Utf8,
            StringArray
        );

        // ASCII fast path must leave multi-byte characters intact.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "naïve-café"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("-")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("_")))),
            ],
            Ok(Some("naïve_café")),
            &str,
            Utf8,
            StringArray
        );

        // Multi-byte pattern goes through the general path.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "héllo wörld"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ö")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("o")))),
            ],
            Ok(Some("héllo world")),
            &str,
            Utf8,
            StringArray
        );

        // Empty pattern follows `str::replace` semantics.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::new()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("X")))),
            ],
            Ok(Some("XaXbXcX")),
            &str,
            Utf8,
            StringArray
        );

        // A null input yields a null output.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("b")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Mixed Utf8 / LargeUtf8 arguments are coerced to LargeUtf8.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("c")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        Ok(())
    }
}