datafusion_functions/string/
replace.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::types::logical_string;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::type_coercion::binary::{
29 binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33 TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
/// Scalar UDF implementing SQL `replace(str, substr, replacement)`.
///
/// The SQL-facing documentation comes from the `#[user_doc]` attribute on this
/// struct; the struct itself only holds the function's [`Signature`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ReplaceFunc {
    /// Argument signature, constructed in `ReplaceFunc::new`.
    signature: Signature,
}
59
60impl Default for ReplaceFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl ReplaceFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::coercible(
70 vec![
71 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74 ],
75 Volatility::Immutable,
76 ),
77 }
78 }
79}
80
81impl ScalarUDFImpl for ReplaceFunc {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "replace"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95 if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
96 .and_then(|dt| string_coercion(&dt, &arg_types[2]))
97 .or_else(|| {
98 binary_to_string_coercion(&arg_types[0], &arg_types[1])
99 .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
100 })
101 {
102 utf8_to_str_type(&coercion_data_type, "replace")
103 } else {
104 exec_err!(
105 "Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View"
106 )
107 }
108 }
109
110 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
111 let data_types = args
112 .args
113 .iter()
114 .map(|arg| arg.data_type())
115 .collect::<Vec<_>>();
116
117 if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
118 .and_then(|dt| string_coercion(&dt, &data_types[2]))
119 .or_else(|| {
120 binary_to_string_coercion(&data_types[0], &data_types[1])
121 .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
122 })
123 {
124 let mut converted_args = Vec::with_capacity(args.args.len());
125 for arg in &args.args {
126 if arg.data_type() == coercion_type {
127 converted_args.push(arg.clone());
128 } else {
129 let converted = arg.cast_to(&coercion_type, None)?;
130 converted_args.push(converted);
131 }
132 }
133
134 match coercion_type {
135 DataType::Utf8 => {
136 make_scalar_function(replace::<i32>, vec![])(&converted_args)
137 }
138 DataType::LargeUtf8 => {
139 make_scalar_function(replace::<i64>, vec![])(&converted_args)
140 }
141 DataType::Utf8View => {
142 make_scalar_function(replace_view, vec![])(&converted_args)
143 }
144 other => exec_err!(
145 "Unsupported coercion data type {other:?} for function replace"
146 ),
147 }
148 } else {
149 exec_err!(
150 "Unsupported data type {}, {:?}, {:?} for function replace.",
151 data_types[0],
152 data_types[1],
153 data_types[2]
154 )
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
164 let string_array = as_string_view_array(&args[0])?;
165 let from_array = as_string_view_array(&args[1])?;
166 let to_array = as_string_view_array(&args[2])?;
167
168 let mut builder = GenericStringBuilder::<i32>::new();
169 let mut buffer = String::new();
170
171 for ((string, from), to) in string_array
172 .iter()
173 .zip(from_array.iter())
174 .zip(to_array.iter())
175 {
176 match (string, from, to) {
177 (Some(string), Some(from), Some(to)) => {
178 buffer.clear();
179 replace_into_string(&mut buffer, string, from, to);
180 builder.append_value(&buffer);
181 }
182 _ => builder.append_null(),
183 }
184 }
185
186 Ok(Arc::new(builder.finish()) as ArrayRef)
187}
188
189fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
192 let string_array = as_generic_string_array::<T>(&args[0])?;
193 let from_array = as_generic_string_array::<T>(&args[1])?;
194 let to_array = as_generic_string_array::<T>(&args[2])?;
195
196 let mut builder = GenericStringBuilder::<T>::new();
197 let mut buffer = String::new();
198
199 for ((string, from), to) in string_array
200 .iter()
201 .zip(from_array.iter())
202 .zip(to_array.iter())
203 {
204 match (string, from, to) {
205 (Some(string), Some(from), Some(to)) => {
206 buffer.clear();
207 replace_into_string(&mut buffer, string, from, to);
208 builder.append_value(&buffer);
209 }
210 _ => builder.append_null(),
211 }
212 }
213
214 Ok(Arc::new(builder.finish()) as ArrayRef)
215}
216
/// Appends `string` to `buffer` with every non-overlapping occurrence of `from`
/// (scanned left to right) replaced by `to`.
///
/// The appended text equals `string.replace(from, to)`, but it is written into
/// the caller's reusable `buffer`. The buffer is appended to, not cleared.
///
/// An empty `from` matches at every character boundary. `to` is therefore
/// inserted before each character and once more at the end; for example `"ab"`
/// with `to = "-"` becomes `"-a-b-"`.
#[inline]
fn replace_into_string(buffer: &mut String, string: &str, from: &str, to: &str) {
    // Empty pattern: surround every character with `to`.
    if from.is_empty() {
        buffer.push_str(to);
        string.chars().for_each(|ch| {
            buffer.push(ch);
            buffer.push_str(to);
        });
        return;
    }

    match (from.as_bytes(), to.as_bytes()) {
        // Fast path: a one-byte ASCII pattern swapped for a one-byte ASCII
        // replacement is a plain byte map; no substring search is needed.
        (&[needle], &[substitute]) if needle.is_ascii() && substitute.is_ascii() => {
            // SAFETY: `needle` is ASCII, and in UTF-8 an ASCII byte only ever
            // appears as a complete one-byte character, never inside a
            // multi-byte sequence. Replacing exactly those bytes with the ASCII
            // `substitute` and copying every other byte unchanged keeps
            // `buffer` valid UTF-8.
            let bytes = unsafe { buffer.as_mut_vec() };
            bytes.extend(
                string
                    .bytes()
                    .map(|b| if b == needle { substitute } else { b }),
            );
        }
        // General path: emit the text between matches, with `to` standing in
        // for each match. `split` always yields at least one piece.
        _ => {
            let mut pieces = string.split(from);
            if let Some(head) = pieces.next() {
                buffer.push_str(head);
            }
            for piece in pieces {
                buffer.push_str(to);
                buffer.push_str(piece);
            }
        }
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::Array;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    /// Runs `replace_into_string` on a fresh buffer and returns the result.
    fn replace_str(string: &str, from: &str, to: &str) -> String {
        let mut buffer = String::new();
        replace_into_string(&mut buffer, string, from, to);
        buffer
    }

    /// Covers each branch of `replace_into_string` directly.
    #[test]
    fn test_replace_into_string() {
        // General substring path; matches are non-overlapping, left to right.
        assert_eq!(replace_str("aabbdqcbb", "bb", "ccc"), "aacccdqcccc");
        assert_eq!(replace_str("aaa", "aa", "b"), "ba");
        // Multi-byte pattern and empty replacement.
        assert_eq!(replace_str("héllo", "é", "e"), "hello");
        assert_eq!(replace_str("banana", "a", ""), "bnn");
        // Empty pattern inserts `to` around every character.
        assert_eq!(replace_str("abc", "", "-"), "-a-b-c-");
        assert_eq!(replace_str("", "", "x"), "x");
        // Single ASCII byte fast path, including non-ASCII surroundings.
        assert_eq!(replace_str("a.b.c", ".", "/"), "a/b/c");
        assert_eq!(replace_str("héllo wörld", "o", "0"), "héll0 wörld");

        // Every path must agree with `str::replace`.
        for (s, from, to) in [
            ("ABabbaBA", "ab", "cd"),
            ("héllo wörld", "ö", "oe"),
            ("abc", "", "-"),
            ("a-b-c", "-", ""),
        ] {
            assert_eq!(replace_str(s, from, to), s.replace(from, to));
        }

        // The buffer is appended to, not cleared.
        let mut buffer = String::from("X");
        replace_into_string(&mut buffer, "ab", "b", "c");
        assert_eq!(buffer, "Xac");
    }

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "aabbb"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "aabbbcw"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))),
            ],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        // Empty pattern: `to` is inserted around every character.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::new()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("-")))),
            ],
            Ok(Some("-a-b-c-")),
            &str,
            Utf8,
            StringArray
        );

        // Single ASCII byte fast path with non-ASCII text around it.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "héllo wörld"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("o")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("0")))),
            ],
            Ok(Some("héll0 wörld")),
            &str,
            Utf8,
            StringArray
        );

        // A null in any argument yields null.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("x")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // Mixed LargeUtf8 / Utf8 arguments are coerced to LargeUtf8.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "a.b.c"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("/")))),
            ],
            Ok(Some("a/b/c")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        Ok(())
    }
}