1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22 ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
23};
24use arrow::datatypes::DataType;
25use datafusion_common::HashMap;
26use unicode_segmentation::UnicodeSegmentation;
27
28use crate::utils::{make_scalar_function, utf8_to_str_type};
29use datafusion_common::{Result, exec_err};
30use datafusion_expr::TypeSignature::Exact;
31use datafusion_expr::{
32 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
33};
34use datafusion_macros::user_doc;
35
// Scalar UDF implementing SQL `translate(str, from, to)`.
//
// The `user_doc` attribute carries the user-facing documentation (section,
// description, syntax, SQL example and argument descriptions).
// NOTE(review): presumably this attribute also generates the `doc()` method
// that `documentation()` calls — confirm against `datafusion_macros`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Performs character-wise substitution based on a mapping.",
    syntax_example = "translate(str, from, to)",
    sql_example = r#"```sql
> select translate('twice', 'wic', 'her');
+--------------------------------------------------+
| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) |
+--------------------------------------------------+
| there |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "from", description = "The characters to be replaced."),
    argument(
        name = "to",
        description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TranslateFunc {
    // Accepted argument-type combinations; populated by `TranslateFunc::new`
    // and returned unchanged by `signature()`.
    signature: Signature,
}
59
60impl Default for TranslateFunc {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl TranslateFunc {
67 pub fn new() -> Self {
68 use DataType::*;
69 Self {
70 signature: Signature::one_of(
71 vec![
72 Exact(vec![Utf8View, Utf8, Utf8]),
73 Exact(vec![Utf8, Utf8, Utf8]),
74 Exact(vec![LargeUtf8, Utf8, Utf8]),
75 ],
76 Volatility::Immutable,
77 ),
78 }
79 }
80}
81
82impl ScalarUDFImpl for TranslateFunc {
83 fn as_any(&self) -> &dyn Any {
84 self
85 }
86
87 fn name(&self) -> &str {
88 "translate"
89 }
90
91 fn signature(&self) -> &Signature {
92 &self.signature
93 }
94
95 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96 utf8_to_str_type(&arg_types[0], "translate")
97 }
98
99 fn invoke_with_args(
100 &self,
101 args: datafusion_expr::ScalarFunctionArgs,
102 ) -> Result<ColumnarValue> {
103 if let (Some(from_str), Some(to_str)) = (
105 try_as_scalar_str(&args.args[1]),
106 try_as_scalar_str(&args.args[2]),
107 ) {
108 let to_graphemes: Vec<&str> = to_str.graphemes(true).collect();
109
110 let mut from_map: HashMap<&str, usize> = HashMap::new();
111 for (index, c) in from_str.graphemes(true).enumerate() {
112 from_map.entry(c).or_insert(index);
114 }
115
116 let ascii_table = build_ascii_translate_table(from_str, to_str);
117
118 let string_array = args.args[0].to_array_of_size(args.number_rows)?;
119
120 let result = match string_array.data_type() {
121 DataType::Utf8View => {
122 let arr = string_array.as_string_view();
123 translate_with_map::<i32, _>(
124 arr,
125 &from_map,
126 &to_graphemes,
127 ascii_table.as_ref(),
128 )
129 }
130 DataType::Utf8 => {
131 let arr = string_array.as_string::<i32>();
132 translate_with_map::<i32, _>(
133 arr,
134 &from_map,
135 &to_graphemes,
136 ascii_table.as_ref(),
137 )
138 }
139 DataType::LargeUtf8 => {
140 let arr = string_array.as_string::<i64>();
141 translate_with_map::<i64, _>(
142 arr,
143 &from_map,
144 &to_graphemes,
145 ascii_table.as_ref(),
146 )
147 }
148 other => {
149 return exec_err!(
150 "Unsupported data type {other:?} for function translate"
151 );
152 }
153 }?;
154
155 return Ok(ColumnarValue::Array(result));
156 }
157
158 make_scalar_function(invoke_translate, vec![])(&args.args)
159 }
160
161 fn documentation(&self) -> Option<&Documentation> {
162 self.doc()
163 }
164}
165
166fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> {
168 match cv {
169 ColumnarValue::Scalar(s) => s.try_as_str().flatten(),
170 _ => None,
171 }
172}
173
174fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
175 match args[0].data_type() {
176 DataType::Utf8View => {
177 let string_array = args[0].as_string_view();
178 let from_array = args[1].as_string::<i32>();
179 let to_array = args[2].as_string::<i32>();
180 translate::<i32, _, _>(string_array, from_array, to_array)
181 }
182 DataType::Utf8 => {
183 let string_array = args[0].as_string::<i32>();
184 let from_array = args[1].as_string::<i32>();
185 let to_array = args[2].as_string::<i32>();
186 translate::<i32, _, _>(string_array, from_array, to_array)
187 }
188 DataType::LargeUtf8 => {
189 let string_array = args[0].as_string::<i64>();
190 let from_array = args[1].as_string::<i32>();
191 let to_array = args[2].as_string::<i32>();
192 translate::<i64, _, _>(string_array, from_array, to_array)
193 }
194 other => {
195 exec_err!("Unsupported data type {other:?} for function translate")
196 }
197 }
198}
199
200fn translate<'a, T: OffsetSizeTrait, V, B>(
203 string_array: V,
204 from_array: B,
205 to_array: B,
206) -> Result<ArrayRef>
207where
208 V: ArrayAccessor<Item = &'a str>,
209 B: ArrayAccessor<Item = &'a str>,
210{
211 let string_array_iter = ArrayIter::new(string_array);
212 let from_array_iter = ArrayIter::new(from_array);
213 let to_array_iter = ArrayIter::new(to_array);
214
215 let mut from_map: HashMap<&str, usize> = HashMap::new();
217 let mut from_graphemes: Vec<&str> = Vec::new();
218 let mut to_graphemes: Vec<&str> = Vec::new();
219 let mut string_graphemes: Vec<&str> = Vec::new();
220 let mut result_graphemes: Vec<&str> = Vec::new();
221
222 let result = string_array_iter
223 .zip(from_array_iter)
224 .zip(to_array_iter)
225 .map(|((string, from), to)| match (string, from, to) {
226 (Some(string), Some(from), Some(to)) => {
227 from_map.clear();
229 from_graphemes.clear();
230 to_graphemes.clear();
231 string_graphemes.clear();
232 result_graphemes.clear();
233
234 from_graphemes.extend(from.graphemes(true));
236 for (index, c) in from_graphemes.iter().enumerate() {
237 from_map.entry(*c).or_insert(index);
239 }
240
241 to_graphemes.extend(to.graphemes(true));
243
244 string_graphemes.extend(string.graphemes(true));
246 for c in &string_graphemes {
247 match from_map.get(*c) {
248 Some(n) => {
249 if let Some(replacement) = to_graphemes.get(*n) {
250 result_graphemes.push(*replacement);
251 }
252 }
253 None => result_graphemes.push(*c),
254 }
255 }
256
257 Some(result_graphemes.concat())
258 }
259 _ => None,
260 })
261 .collect::<GenericStringArray<T>>();
262
263 Ok(Arc::new(result) as ArrayRef)
264}
265
/// Marker stored in the ASCII translate table for bytes that must be removed.
///
/// Every real table entry comes from ASCII input and is therefore at most 127,
/// so `0xFF` can never collide with a replacement byte.
const ASCII_DELETE: u8 = 0xFF;

/// Builds a 128-entry byte lookup table for an all-ASCII `from` / `to` pair.
///
/// Entry `b` holds the byte that replaces `b`, `b` itself when it does not
/// occur in `from`, or [`ASCII_DELETE`] when `to` has no byte at the matching
/// position. When a byte occurs several times in `from`, only its first
/// occurrence is used.
///
/// Returns `None` when either string contains a non-ASCII character.
fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
    if !(from.is_ascii() && to.is_ascii()) {
        return None;
    }

    // Start from the identity mapping.
    let mut table: [u8; 128] = std::array::from_fn(|byte| byte as u8);
    let mut already_mapped = [false; 128];
    let replacements = to.as_bytes();

    for (position, source) in from.bytes().enumerate() {
        let slot = usize::from(source);
        // `replace` returns the previous flag: `true` means an earlier
        // occurrence already decided this byte's mapping.
        if std::mem::replace(&mut already_mapped[slot], true) {
            continue;
        }
        table[slot] = replacements
            .get(position)
            .copied()
            .unwrap_or(ASCII_DELETE);
    }

    Some(table)
}
298
299fn translate_with_map<'a, T: OffsetSizeTrait, V>(
304 string_array: V,
305 from_map: &HashMap<&str, usize>,
306 to_graphemes: &[&str],
307 ascii_table: Option<&[u8; 128]>,
308) -> Result<ArrayRef>
309where
310 V: ArrayAccessor<Item = &'a str>,
311{
312 let mut result_graphemes: Vec<&str> = Vec::new();
313 let mut ascii_buf: Vec<u8> = Vec::new();
314
315 let result = ArrayIter::new(string_array)
316 .map(|string| {
317 string.map(|s| {
318 if let Some(table) = ascii_table
320 && s.is_ascii()
321 {
322 ascii_buf.clear();
323 for &b in s.as_bytes() {
324 let mapped = table[b as usize];
325 if mapped != ASCII_DELETE {
326 ascii_buf.push(mapped);
327 }
328 }
329 return unsafe {
331 std::str::from_utf8_unchecked(&ascii_buf).to_owned()
332 };
333 }
334
335 result_graphemes.clear();
337
338 for c in s.graphemes(true) {
339 match from_map.get(c) {
340 Some(n) => {
341 if let Some(replacement) = to_graphemes.get(*n) {
342 result_graphemes.push(*replacement);
343 }
344 }
345 None => result_graphemes.push(c),
346 }
347 }
348
349 result_graphemes.concat()
350 })
351 })
352 .collect::<GenericStringArray<T>>();
353
354 Ok(Arc::new(result) as ArrayRef)
355}
356
// Unit tests for `TranslateFunc`, driven through the shared `test_function!`
// helper with scalar arguments only.
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::translate::TranslateFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        // Basic mapping: '1' -> 'a', '4' -> 'x', and '3' is removed because
        // `to` has no third character.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8,
            StringArray
        );
        // NULL string argument gives NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // NULL `from` argument gives NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // NULL `to` argument gives NULL.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        // Duplicate character in `from`: the first occurrence ('a' -> 'd')
        // decides the mapping, the second ('a' -> 'e') is ignored.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abcabc")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("de"))
            ],
            Ok(Some("dbcdbc")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII `from` / `to`: no ASCII table can be built, so the
        // grapheme mapping is used; 'í' is removed because `to` is shorter.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("é2íñ5")),
                ColumnarValue::Scalar(ScalarValue::from("éñí")),
                ColumnarValue::Scalar(ScalarValue::from("óü")),
            ],
            Ok(Some("ó2ü5")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII input with ASCII scalar `from` / `to`: the ASCII table
        // exists, but the row is not ASCII, so `translate_with_map` falls back
        // to the grapheme path.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("café")),
                ColumnarValue::Scalar(ScalarValue::from("ae")),
                ColumnarValue::Scalar(ScalarValue::from("AE"))
            ],
            Ok(Some("cAfé")),
            &str,
            Utf8,
            StringArray
        );

        // Only compiled when the `unicode_expressions` feature is disabled.
        // NOTE(review): `internal_err!` is not imported in this module, so this
        // case looks like it would not compile as written with the feature
        // off — confirm, and add `use datafusion_common::internal_err;` if so.
        #[cfg(not(feature = "unicode_expressions"))]
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax")),
            ],
            internal_err!(
                "function translate requires compilation with feature flag: unicode_expressions."
            ),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}