datafusion_functions/string/
replace.rs1use std::sync::Arc;
19
20use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
21use arrow::buffer::NullBuffer;
22use arrow::datatypes::DataType;
23
24use crate::strings::{
25 BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringWriter,
26};
27use crate::utils::{make_scalar_function, utf8_to_str_type};
28use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29use datafusion_common::types::logical_string;
30use datafusion_common::{Result, exec_err};
31use datafusion_expr::type_coercion::binary::{
32 binary_to_string_coercion, string_coercion,
33};
34use datafusion_expr::{
35 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36 TypeSignatureClass, Volatility,
37};
38use datafusion_macros::user_doc;
// Documentation metadata for the `replace` SQL function. The `user_doc`
// attribute macro presumably generates the `doc()` accessor that
// `ScalarUDFImpl::documentation` returns (no hand-written `doc()` is visible
// in this file) — NOTE(review): confirm against `datafusion_macros`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ReplaceFunc {
    // Accepted argument types; constructed in `ReplaceFunc::new` as three
    // string arguments.
    signature: Signature,
}
62
63impl Default for ReplaceFunc {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl ReplaceFunc {
70 pub fn new() -> Self {
71 Self {
72 signature: Signature::coercible(
73 vec![
74 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
75 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
76 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
77 ],
78 Volatility::Immutable,
79 ),
80 }
81 }
82}
83
84impl ScalarUDFImpl for ReplaceFunc {
85 fn name(&self) -> &str {
86 "replace"
87 }
88
89 fn signature(&self) -> &Signature {
90 &self.signature
91 }
92
93 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94 if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
95 .and_then(|dt| string_coercion(&dt, &arg_types[2]))
96 .or_else(|| {
97 binary_to_string_coercion(&arg_types[0], &arg_types[1])
98 .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
99 })
100 {
101 utf8_to_str_type(&coercion_data_type, "replace")
102 } else {
103 exec_err!(
104 "Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View"
105 )
106 }
107 }
108
109 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110 let data_types = args
111 .args
112 .iter()
113 .map(|arg| arg.data_type())
114 .collect::<Vec<_>>();
115
116 if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
117 .and_then(|dt| string_coercion(&dt, &data_types[2]))
118 .or_else(|| {
119 binary_to_string_coercion(&data_types[0], &data_types[1])
120 .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
121 })
122 {
123 let mut converted_args = Vec::with_capacity(args.args.len());
124 for arg in &args.args {
125 if arg.data_type() == coercion_type {
126 converted_args.push(arg.clone());
127 } else {
128 let converted = arg.cast_to(&coercion_type, None)?;
129 converted_args.push(converted);
130 }
131 }
132
133 match coercion_type {
134 DataType::Utf8 => {
135 make_scalar_function(replace::<i32>, vec![])(&converted_args)
136 }
137 DataType::LargeUtf8 => {
138 make_scalar_function(replace::<i64>, vec![])(&converted_args)
139 }
140 DataType::Utf8View => {
141 make_scalar_function(replace_view, vec![])(&converted_args)
142 }
143 other => exec_err!(
144 "Unsupported coercion data type {other:?} for function replace"
145 ),
146 }
147 } else {
148 exec_err!(
149 "Unsupported data type {}, {:?}, {:?} for function replace.",
150 data_types[0],
151 data_types[1],
152 data_types[2]
153 )
154 }
155 }
156
157 fn documentation(&self) -> Option<&Documentation> {
158 self.doc()
159 }
160}
161
162fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
163 let string_array = as_string_view_array(&args[0])?;
164 let from_array = as_string_view_array(&args[1])?;
165 let to_array = as_string_view_array(&args[2])?;
166
167 let len = string_array.len();
168 let mut builder = GenericStringArrayBuilder::<i32>::with_capacity(len, 0);
169 let nulls = NullBuffer::union_many([
170 string_array.nulls(),
171 from_array.nulls(),
172 to_array.nulls(),
173 ]);
174
175 if let Some(nulls_ref) = nulls.as_ref() {
179 for i in 0..len {
180 if nulls_ref.is_null(i) {
181 builder.append_placeholder();
182 continue;
183 }
184 let string = unsafe { string_array.value_unchecked(i) };
186 let from = unsafe { from_array.value_unchecked(i) };
187 let to = unsafe { to_array.value_unchecked(i) };
188 apply_replace(&mut builder, string, from, to);
189 }
190 } else {
191 for i in 0..len {
192 let string = unsafe { string_array.value_unchecked(i) };
194 let from = unsafe { from_array.value_unchecked(i) };
195 let to = unsafe { to_array.value_unchecked(i) };
196 apply_replace(&mut builder, string, from, to);
197 }
198 }
199
200 Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
201}
202
203fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
206 let string_array = as_generic_string_array::<T>(&args[0])?;
207 let from_array = as_generic_string_array::<T>(&args[1])?;
208 let to_array = as_generic_string_array::<T>(&args[2])?;
209
210 let len = string_array.len();
211 let mut builder = GenericStringArrayBuilder::<T>::with_capacity(len, 0);
212 let nulls = NullBuffer::union_many([
213 string_array.nulls(),
214 from_array.nulls(),
215 to_array.nulls(),
216 ]);
217
218 if let Some(nulls_ref) = nulls.as_ref() {
222 for i in 0..len {
223 if nulls_ref.is_null(i) {
224 builder.append_placeholder();
225 continue;
226 }
227 let string = unsafe { string_array.value_unchecked(i) };
229 let from = unsafe { from_array.value_unchecked(i) };
230 let to = unsafe { to_array.value_unchecked(i) };
231 apply_replace(&mut builder, string, from, to);
232 }
233 } else {
234 for i in 0..len {
235 let string = unsafe { string_array.value_unchecked(i) };
237 let from = unsafe { from_array.value_unchecked(i) };
238 let to = unsafe { to_array.value_unchecked(i) };
239 apply_replace(&mut builder, string, from, to);
240 }
241 }
242
243 Ok(Arc::new(builder.finish(nulls)?) as ArrayRef)
244}
245
246#[inline]
247fn apply_replace<B: BulkNullStringArrayBuilder>(
248 builder: &mut B,
249 string: &str,
250 from: &str,
251 to: &str,
252) {
253 if let (&[from_byte], &[to_byte]) = (from.as_bytes(), to.as_bytes())
258 && from_byte.is_ascii()
259 && to_byte.is_ascii()
260 {
261 unsafe {
263 builder.append_byte_map(string.as_bytes(), |b| {
264 if b == from_byte { to_byte } else { b }
265 });
266 }
267 return;
268 }
269
270 if from.is_empty() {
271 builder.append_with(|w| {
273 w.write_str(to);
274 for ch in string.chars() {
275 w.write_char(ch);
276 w.write_str(to);
277 }
278 });
279 return;
280 }
281
282 builder.append_with(|w| replace_into_writer(w, string, from, to));
283}
284
285#[inline]
286fn replace_into_writer<W: StringWriter>(w: &mut W, string: &str, from: &str, to: &str) {
287 let mut last_end = 0;
288 for (start, _part) in string.match_indices(from) {
289 w.write_str(&string[last_end..start]);
290 w.write_str(to);
291 last_end = start + from.len();
292 }
293 w.write_str(&string[last_end..]);
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "aabbb"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // Utf8View inputs produce a Utf8 (i32-offset) result.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "aabbbcw"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))),
            ],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// Single-ASCII-byte `from`/`to` takes the byte-map fast path. Multi-byte
    /// UTF-8 characters must pass through it untouched.
    #[test]
    fn test_single_ascii_byte_fast_path() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "héllo wörld"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("l")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("L")))),
            ],
            Ok(Some("héLLo wörLd")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("a.b.c")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("-")))),
            ],
            Ok(Some("a-b-c")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        Ok(())
    }

    /// An empty `from` inserts `to` before every character and once at the end.
    #[test]
    fn test_empty_from() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("-")))),
            ],
            Ok(Some("-a-b-c-")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("x")))),
            ],
            Ok(Some("x")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// A null in any argument position yields a null result.
    #[test]
    fn test_null_arguments() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("b")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("x")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("b")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}