1use arrow::array::{
19 ArrayAccessor, ArrayIter, ArrayRef, AsArray, LargeStringBuilder, StringBuilder,
20 StringLikeArrayBuilder, StringViewBuilder,
21};
22use arrow::datatypes::DataType;
23use datafusion_common::HashMap;
24
25use crate::utils::make_scalar_function;
26use datafusion_common::{Result, exec_err};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 Volatility,
31};
32use datafusion_macros::user_doc;
33
34#[user_doc(
35 doc_section(label = "String Functions"),
36 description = "Performs character-wise substitution based on a mapping.",
37 syntax_example = "translate(str, from, to)",
38 sql_example = r#"```sql
39> select translate('twice', 'wic', 'her');
40+--------------------------------------------------+
41| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) |
42+--------------------------------------------------+
43| there |
44+--------------------------------------------------+
45```"#,
46 standard_argument(name = "str", prefix = "String"),
47 argument(name = "from", description = "The characters to be replaced."),
48 argument(
49 name = "to",
50 description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping."
51 )
52)]
/// Scalar UDF implementing the SQL `translate(str, from, to)` function.
///
/// The only state is the [`Signature`] built in [`TranslateFunc::new`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TranslateFunc {
    /// Accepted argument types, returned by `ScalarUDFImpl::signature`.
    signature: Signature,
}
57
impl Default for TranslateFunc {
    /// Returns the same value as [`TranslateFunc::new`].
    fn default() -> Self {
        Self::new()
    }
}
63
64impl TranslateFunc {
65 pub fn new() -> Self {
66 use DataType::*;
67 Self {
68 signature: Signature::one_of(
69 vec![
70 Exact(vec![Utf8View, Utf8, Utf8]),
71 Exact(vec![Utf8, Utf8, Utf8]),
72 Exact(vec![LargeUtf8, Utf8, Utf8]),
73 ],
74 Volatility::Immutable,
75 ),
76 }
77 }
78}
79
80impl ScalarUDFImpl for TranslateFunc {
81 fn name(&self) -> &str {
82 "translate"
83 }
84
85 fn signature(&self) -> &Signature {
86 &self.signature
87 }
88
89 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
90 Ok(arg_types[0].clone())
91 }
92
93 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94 if let (Some(from_str), Some(to_str)) = (
96 try_as_scalar_str(&args.args[1]),
97 try_as_scalar_str(&args.args[2]),
98 ) {
99 let to_chars: Vec<char> = to_str.chars().collect();
100
101 let mut from_map: HashMap<char, usize> = HashMap::new();
102 for (index, c) in from_str.chars().enumerate() {
103 from_map.entry(c).or_insert(index);
104 }
105
106 let ascii_table = build_ascii_translate_table(from_str, to_str);
107
108 let string_array = args.args[0].to_array_of_size(args.number_rows)?;
109 let len = string_array.len();
110
111 let result = match string_array.data_type() {
112 DataType::Utf8View => {
113 let arr = string_array.as_string_view();
114 let builder = StringViewBuilder::with_capacity(len);
115 translate_with_map(
116 arr,
117 &from_map,
118 &to_chars,
119 ascii_table.as_ref(),
120 builder,
121 )
122 }
123 DataType::Utf8 => {
124 let arr = string_array.as_string::<i32>();
125 let builder =
126 StringBuilder::with_capacity(len, arr.value_data().len());
127 translate_with_map(
128 arr,
129 &from_map,
130 &to_chars,
131 ascii_table.as_ref(),
132 builder,
133 )
134 }
135 DataType::LargeUtf8 => {
136 let arr = string_array.as_string::<i64>();
137 let builder =
138 LargeStringBuilder::with_capacity(len, arr.value_data().len());
139 translate_with_map(
140 arr,
141 &from_map,
142 &to_chars,
143 ascii_table.as_ref(),
144 builder,
145 )
146 }
147 other => {
148 return exec_err!(
149 "Unsupported data type {other:?} for function translate"
150 );
151 }
152 }?;
153
154 return Ok(ColumnarValue::Array(result));
155 }
156
157 make_scalar_function(invoke_translate, vec![])(&args.args)
158 }
159
160 fn documentation(&self) -> Option<&Documentation> {
161 self.doc()
162 }
163}
164
165use super::common::try_as_scalar_str;
166
167fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
168 let len = args[0].len();
169 match args[0].data_type() {
170 DataType::Utf8View => {
171 let string_array = args[0].as_string_view();
172 let from_array = args[1].as_string::<i32>();
173 let to_array = args[2].as_string::<i32>();
174 let builder = StringViewBuilder::with_capacity(len);
175 translate(string_array, from_array, to_array, builder)
176 }
177 DataType::Utf8 => {
178 let string_array = args[0].as_string::<i32>();
179 let from_array = args[1].as_string::<i32>();
180 let to_array = args[2].as_string::<i32>();
181 let builder =
182 StringBuilder::with_capacity(len, string_array.value_data().len());
183 translate(string_array, from_array, to_array, builder)
184 }
185 DataType::LargeUtf8 => {
186 let string_array = args[0].as_string::<i64>();
187 let from_array = args[1].as_string::<i32>();
188 let to_array = args[2].as_string::<i32>();
189 let builder =
190 LargeStringBuilder::with_capacity(len, string_array.value_data().len());
191 translate(string_array, from_array, to_array, builder)
192 }
193 other => {
194 exec_err!("Unsupported data type {other:?} for function translate")
195 }
196 }
197}
198
199fn translate<'a, V, B, O>(
202 string_array: V,
203 from_array: B,
204 to_array: B,
205 mut builder: O,
206) -> Result<ArrayRef>
207where
208 V: ArrayAccessor<Item = &'a str>,
209 B: ArrayAccessor<Item = &'a str>,
210 O: StringLikeArrayBuilder,
211{
212 let string_array_iter = ArrayIter::new(string_array);
213 let from_array_iter = ArrayIter::new(from_array);
214 let to_array_iter = ArrayIter::new(to_array);
215
216 let mut from_map: HashMap<char, usize> = HashMap::new();
217 let mut to_chars: Vec<char> = Vec::new();
218 let mut result_buf = String::new();
219
220 for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter)
221 {
222 match (string, from, to) {
223 (Some(string), Some(from), Some(to)) => {
224 from_map.clear();
225 to_chars.clear();
226 result_buf.clear();
227
228 for (index, c) in from.chars().enumerate() {
229 from_map.entry(c).or_insert(index);
230 }
231
232 to_chars.extend(to.chars());
233
234 translate_char_by_char(string, &from_map, &to_chars, &mut result_buf);
235
236 builder.append_value(&result_buf);
237 }
238 _ => builder.append_null(),
239 }
240 }
241
242 Ok(builder.finish())
243}
244
/// Appends the translation of `input` to `buf`.
///
/// For each character of `input`:
/// - if it is not a key of `from_map`, it is appended unchanged;
/// - if it maps to an index that exists in `to_chars`, that character is
///   appended instead;
/// - otherwise (index past the end of `to_chars`) nothing is appended.
///
/// `buf` is not cleared; output is added after its existing contents.
#[inline]
fn translate_char_by_char(
    input: &str,
    from_map: &HashMap<char, usize>,
    to_chars: &[char],
    buf: &mut String,
) {
    let translated = input.chars().filter_map(|ch| match from_map.get(&ch) {
        None => Some(ch),
        Some(&index) => to_chars.get(index).copied(),
    });
    buf.extend(translated);
}
265
/// Marker stored in the ASCII translate table for "remove this byte".
///
/// The table is only built from ASCII strings, whose bytes are all below 128,
/// so `0xFF` can never be a real replacement byte.
const ASCII_DELETE: u8 = 0xFF;

/// Builds a 128-entry byte lookup table for an all-ASCII `from` / `to` pair.
///
/// Entry `b` holds the byte that replaces input byte `b`: the byte itself when
/// `b` does not occur in `from`, the byte of `to` at the position of the first
/// occurrence of `b` in `from`, or [`ASCII_DELETE`] when `to` is too short to
/// have a byte at that position.
///
/// Returns `None` when either `from` or `to` contains a non-ASCII character.
fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
    if !(from.is_ascii() && to.is_ascii()) {
        return None;
    }

    // Start from the identity mapping; indices are 0..128, so the cast is lossless.
    let mut table: [u8; 128] = std::array::from_fn(|byte| byte as u8);
    let replacements = to.as_bytes();
    let mut assigned = [false; 128];

    for (position, source) in from.bytes().enumerate() {
        let slot = usize::from(source);
        // A repeated `from` byte keeps the mapping of its first occurrence.
        if assigned[slot] {
            continue;
        }
        assigned[slot] = true;
        table[slot] = replacements.get(position).copied().unwrap_or(ASCII_DELETE);
    }

    Some(table)
}
298
299fn translate_with_map<'a, V, O>(
304 string_array: V,
305 from_map: &HashMap<char, usize>,
306 to_chars: &[char],
307 ascii_table: Option<&[u8; 128]>,
308 mut builder: O,
309) -> Result<ArrayRef>
310where
311 V: ArrayAccessor<Item = &'a str>,
312 O: StringLikeArrayBuilder,
313{
314 let mut result_buf = String::new();
315 let mut ascii_buf: Vec<u8> = Vec::new();
316
317 for string in ArrayIter::new(string_array) {
318 match string {
319 Some(s) => {
320 if let Some(table) = ascii_table
322 && s.is_ascii()
323 {
324 ascii_buf.clear();
325 for &b in s.as_bytes() {
326 let mapped = table[b as usize];
327 if mapped != ASCII_DELETE {
328 ascii_buf.push(mapped);
329 }
330 }
331 builder.append_value(unsafe {
333 std::str::from_utf8_unchecked(&ascii_buf)
334 });
335 } else {
336 result_buf.clear();
337 translate_char_by_char(s, from_map, to_chars, &mut result_buf);
338 builder.append_value(&result_buf);
339 }
340 }
341 None => builder.append_null(),
342 }
343 }
344
345 Ok(builder.finish())
346}
347
// Unit tests for `TranslateFunc`; every case passes all three arguments as
// scalar values through the `test_function!` macro.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{Utf8, Utf8View};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::translate::TranslateFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        // '1' -> 'a', '4' -> 'x'; '3' has no counterpart in `to` and is removed.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8,
            StringArray
        );
        // NULL `str` is expected to give NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // NULL `from` is expected to give NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // NULL `to` is expected to give NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Repeated character in `from`: the first occurrence ('a' -> 'd') wins.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abcabc")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("de"))
            ],
            Ok(Some("dbcdbc")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII `str`, `from` and `to`: 'é' -> 'ó', 'ñ' -> 'ü', 'í' is removed.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("é2íñ5")),
                ColumnarValue::Scalar(ScalarValue::from("éñí")),
                ColumnarValue::Scalar(ScalarValue::from("óü")),
            ],
            Ok(Some("ó2ü5")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII `str` with ASCII `from` / `to`: the input is not ASCII, so
        // the ASCII table fast path in `translate_with_map` does not apply.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("café")),
                ColumnarValue::Scalar(ScalarValue::from("ae")),
                ColumnarValue::Scalar(ScalarValue::from("AE"))
            ],
            Ok(Some("cAfé")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8View `str` is expected to produce a Utf8View result.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8View,
            StringViewArray
        );
        // NULL Utf8View `str` is expected to give NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII Utf8View `str` with non-ASCII `from` / `to`.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))),
                ColumnarValue::Scalar(ScalarValue::from("éñí")),
                ColumnarValue::Scalar(ScalarValue::from("óü"))
            ],
            Ok(Some("ó2ü5")),
            &str,
            Utf8View,
            StringViewArray
        );

        // Only compiled when the `unicode_expressions` feature is disabled.
        // NOTE(review): `internal_err!` is not imported in this module, so this
        // case presumably fails to compile if it is ever enabled — confirm.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax")),
            ],
            internal_err!(
                "function translate requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}