datafusion_functions/string/
replace.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::types::logical_string;
27use datafusion_common::{Result, exec_err};
28use datafusion_expr::type_coercion::binary::{
29 binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32 Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33 TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA                                        |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
/// Scalar UDF implementing SQL `replace(str, substr, replacement)`.
///
/// All three arguments are coerced to a common string type before the
/// replacement kernel runs. A NULL in any argument yields a NULL row, and
/// `Utf8View` inputs produce a `Utf8` result. An empty `substr` inserts
/// `replacement` before every character and once more at the end.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ReplaceFunc {
    /// Three arguments, each accepted via an exact native-string coercion.
    signature: Signature,
}
59
60impl Default for ReplaceFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl ReplaceFunc {
67 pub fn new() -> Self {
68 Self {
69 signature: Signature::coercible(
70 vec![
71 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73 Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74 ],
75 Volatility::Immutable,
76 ),
77 }
78 }
79}
80
81impl ScalarUDFImpl for ReplaceFunc {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "replace"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95 if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
96 .and_then(|dt| string_coercion(&dt, &arg_types[2]))
97 .or_else(|| {
98 binary_to_string_coercion(&arg_types[0], &arg_types[1])
99 .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
100 })
101 {
102 utf8_to_str_type(&coercion_data_type, "replace")
103 } else {
104 exec_err!(
105 "Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View"
106 )
107 }
108 }
109
110 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
111 let data_types = args
112 .args
113 .iter()
114 .map(|arg| arg.data_type())
115 .collect::<Vec<_>>();
116
117 if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
118 .and_then(|dt| string_coercion(&dt, &data_types[2]))
119 .or_else(|| {
120 binary_to_string_coercion(&data_types[0], &data_types[1])
121 .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
122 })
123 {
124 let mut converted_args = Vec::with_capacity(args.args.len());
125 for arg in &args.args {
126 if arg.data_type() == coercion_type {
127 converted_args.push(arg.clone());
128 } else {
129 let converted = arg.cast_to(&coercion_type, None)?;
130 converted_args.push(converted);
131 }
132 }
133
134 match coercion_type {
135 DataType::Utf8 => {
136 make_scalar_function(replace::<i32>, vec![])(&converted_args)
137 }
138 DataType::LargeUtf8 => {
139 make_scalar_function(replace::<i64>, vec![])(&converted_args)
140 }
141 DataType::Utf8View => {
142 make_scalar_function(replace_view, vec![])(&converted_args)
143 }
144 other => exec_err!(
145 "Unsupported coercion data type {other:?} for function replace"
146 ),
147 }
148 } else {
149 exec_err!(
150 "Unsupported data type {}, {:?}, {:?} for function replace.",
151 data_types[0],
152 data_types[1],
153 data_types[2]
154 )
155 }
156 }
157
158 fn documentation(&self) -> Option<&Documentation> {
159 self.doc()
160 }
161}
162
163fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
164 let string_array = as_string_view_array(&args[0])?;
165 let from_array = as_string_view_array(&args[1])?;
166 let to_array = as_string_view_array(&args[2])?;
167
168 let mut builder = GenericStringBuilder::<i32>::new();
169 let mut buffer = String::new();
170
171 for ((string, from), to) in string_array
172 .iter()
173 .zip(from_array.iter())
174 .zip(to_array.iter())
175 {
176 match (string, from, to) {
177 (Some(string), Some(from), Some(to)) => {
178 buffer.clear();
179 replace_into_string(&mut buffer, string, from, to);
180 builder.append_value(&buffer);
181 }
182 _ => builder.append_null(),
183 }
184 }
185
186 Ok(Arc::new(builder.finish()) as ArrayRef)
187}
188
189fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
192 let string_array = as_generic_string_array::<T>(&args[0])?;
193 let from_array = as_generic_string_array::<T>(&args[1])?;
194 let to_array = as_generic_string_array::<T>(&args[2])?;
195
196 let mut builder = GenericStringBuilder::<T>::new();
197 let mut buffer = String::new();
198
199 for ((string, from), to) in string_array
200 .iter()
201 .zip(from_array.iter())
202 .zip(to_array.iter())
203 {
204 match (string, from, to) {
205 (Some(string), Some(from), Some(to)) => {
206 buffer.clear();
207 replace_into_string(&mut buffer, string, from, to);
208 builder.append_value(&buffer);
209 }
210 _ => builder.append_null(),
211 }
212 }
213
214 Ok(Arc::new(builder.finish()) as ArrayRef)
215}
216
/// Appends `string` to `buffer` with every non-overlapping occurrence of
/// `from` replaced by `to`, scanning left to right.
///
/// The appended text equals `string.replace(from, to)`, including the
/// empty-pattern case. There, `to` is inserted before every character and
/// once more at the end (e.g. `("abc", "", "x")` yields `"xaxbxcx"`).
///
/// `buffer` is appended to, not cleared; callers that reuse it across rows
/// must clear it themselves.
#[inline]
fn replace_into_string(buffer: &mut String, string: &str, from: &str, to: &str) {
    if from.is_empty() {
        buffer.push_str(to);
        for ch in string.chars() {
            buffer.push(ch);
            buffer.push_str(to);
        }
        return;
    }

    match (from.as_bytes(), to.as_bytes()) {
        // Fast path: one ASCII byte replaced by another ASCII byte. This is a
        // plain byte-wise map (likely vectorizable), written straight into
        // `buffer` so no temporary allocation is made per row.
        (&[from_byte], &[to_byte]) if from_byte.is_ascii() && to_byte.is_ascii() => {
            // SAFETY: `buffer` holds valid UTF-8 on entry. We append the bytes
            // of `string` (valid UTF-8), swapping only bytes equal to the ASCII
            // `from_byte` for the ASCII `to_byte`. In UTF-8, every byte < 0x80
            // is a complete one-byte character and never occurs inside a
            // multi-byte sequence, so the appended bytes remain valid UTF-8.
            // The mapping closure cannot panic mid-write.
            let bytes = unsafe { buffer.as_mut_vec() };
            bytes.extend(
                string
                    .bytes()
                    .map(|b| if b == from_byte { to_byte } else { b }),
            );
        }
        _ => {
            let mut last_end = 0;
            for (start, matched) in string.match_indices(from) {
                buffer.push_str(&string[last_end..start]);
                buffer.push_str(to);
                last_end = start + matched.len();
            }
            buffer.push_str(&string[last_end..]);
        }
    }
}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::Array;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "aabbb"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // Utf8View inputs are replaced into a Utf8 (not view) result.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "aabbbcw"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))),
            ],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }

    /// Covers the branches of `replace_into_string` plus NULL propagation.
    ///
    /// The branches are: the empty pattern, the single-ASCII-byte fast path
    /// over multi-byte text, and a multi-byte pattern on the general path.
    #[test]
    fn test_edge_cases() -> Result<()> {
        // Empty pattern: `to` goes before every character and at the end.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::new()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("x")))),
            ],
            Ok(Some("xaxbxcx")),
            &str,
            Utf8,
            StringArray
        );

        // ASCII -> ASCII fast path must leave multi-byte characters intact.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                    "héllo.wörld"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(".")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("_")))),
            ],
            Ok(Some("héllo_wörld")),
            &str,
            Utf8,
            StringArray
        );

        // Multi-byte pattern goes through the general match path.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "héllo wörld"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("ö")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("o")))),
            ],
            Ok(Some("héllo world")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // A NULL in any argument yields NULL.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("b")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}