Skip to main content

datafusion_functions/unicode/
translate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
23};
24use arrow::datatypes::DataType;
25use datafusion_common::HashMap;
26use unicode_segmentation::UnicodeSegmentation;
27
28use crate::utils::{make_scalar_function, utf8_to_str_type};
29use datafusion_common::{Result, exec_err};
30use datafusion_expr::TypeSignature::Exact;
31use datafusion_expr::{
32    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
33};
34use datafusion_macros::user_doc;
35
/// Scalar UDF implementing `translate(str, from, to)`.
///
/// Each grapheme of `str` that appears in `from` is replaced by the grapheme at
/// the same position in `to`, or removed when `to` has no grapheme at that
/// position; graphemes not present in `from` are kept unchanged. If a grapheme
/// appears more than once in `from`, its first occurrence determines the
/// mapping. A null in any argument produces a null result.
///
/// The SQL-facing documentation returned by `documentation()` is declared via
/// the `#[user_doc]` attribute below.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Performs character-wise substitution based on a mapping.",
    syntax_example = "translate(str, from, to)",
    sql_example = r#"```sql
> select translate('twice', 'wic', 'her');
+--------------------------------------------------+
| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) |
+--------------------------------------------------+
| there                                            |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "from", description = "The characters to be replaced."),
    argument(
        name = "to",
        description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TranslateFunc {
    /// Accepted argument types: `str` may be `Utf8View`, `Utf8` or
    /// `LargeUtf8`; `from` and `to` are declared as `Utf8`.
    signature: Signature,
}
59
60impl Default for TranslateFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl TranslateFunc {
67    pub fn new() -> Self {
68        use DataType::*;
69        Self {
70            signature: Signature::one_of(
71                vec![
72                    Exact(vec![Utf8View, Utf8, Utf8]),
73                    Exact(vec![Utf8, Utf8, Utf8]),
74                    Exact(vec![LargeUtf8, Utf8, Utf8]),
75                ],
76                Volatility::Immutable,
77            ),
78        }
79    }
80}
81
82impl ScalarUDFImpl for TranslateFunc {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn name(&self) -> &str {
88        "translate"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
96        utf8_to_str_type(&arg_types[0], "translate")
97    }
98
99    fn invoke_with_args(
100        &self,
101        args: datafusion_expr::ScalarFunctionArgs,
102    ) -> Result<ColumnarValue> {
103        // When from and to are scalars, pre-build the translation map once
104        if let (Some(from_str), Some(to_str)) = (
105            try_as_scalar_str(&args.args[1]),
106            try_as_scalar_str(&args.args[2]),
107        ) {
108            let to_graphemes: Vec<&str> = to_str.graphemes(true).collect();
109
110            let mut from_map: HashMap<&str, usize> = HashMap::new();
111            for (index, c) in from_str.graphemes(true).enumerate() {
112                // Ignore characters that already exist in from_map
113                from_map.entry(c).or_insert(index);
114            }
115
116            let ascii_table = build_ascii_translate_table(from_str, to_str);
117
118            let string_array = args.args[0].to_array_of_size(args.number_rows)?;
119
120            let result = match string_array.data_type() {
121                DataType::Utf8View => {
122                    let arr = string_array.as_string_view();
123                    translate_with_map::<i32, _>(
124                        arr,
125                        &from_map,
126                        &to_graphemes,
127                        ascii_table.as_ref(),
128                    )
129                }
130                DataType::Utf8 => {
131                    let arr = string_array.as_string::<i32>();
132                    translate_with_map::<i32, _>(
133                        arr,
134                        &from_map,
135                        &to_graphemes,
136                        ascii_table.as_ref(),
137                    )
138                }
139                DataType::LargeUtf8 => {
140                    let arr = string_array.as_string::<i64>();
141                    translate_with_map::<i64, _>(
142                        arr,
143                        &from_map,
144                        &to_graphemes,
145                        ascii_table.as_ref(),
146                    )
147                }
148                other => {
149                    return exec_err!(
150                        "Unsupported data type {other:?} for function translate"
151                    );
152                }
153            }?;
154
155            return Ok(ColumnarValue::Array(result));
156        }
157
158        make_scalar_function(invoke_translate, vec![])(&args.args)
159    }
160
161    fn documentation(&self) -> Option<&Documentation> {
162        self.doc()
163    }
164}
165
166/// If `cv` is a non-null scalar string, return its value.
167fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> {
168    match cv {
169        ColumnarValue::Scalar(s) => s.try_as_str().flatten(),
170        _ => None,
171    }
172}
173
174fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
175    match args[0].data_type() {
176        DataType::Utf8View => {
177            let string_array = args[0].as_string_view();
178            let from_array = args[1].as_string::<i32>();
179            let to_array = args[2].as_string::<i32>();
180            translate::<i32, _, _>(string_array, from_array, to_array)
181        }
182        DataType::Utf8 => {
183            let string_array = args[0].as_string::<i32>();
184            let from_array = args[1].as_string::<i32>();
185            let to_array = args[2].as_string::<i32>();
186            translate::<i32, _, _>(string_array, from_array, to_array)
187        }
188        DataType::LargeUtf8 => {
189            let string_array = args[0].as_string::<i64>();
190            let from_array = args[1].as_string::<i32>();
191            let to_array = args[2].as_string::<i32>();
192            translate::<i64, _, _>(string_array, from_array, to_array)
193        }
194        other => {
195            exec_err!("Unsupported data type {other:?} for function translate")
196        }
197    }
198}
199
200/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
201/// translate('12345', '143', 'ax') = 'a2x5'
202fn translate<'a, T: OffsetSizeTrait, V, B>(
203    string_array: V,
204    from_array: B,
205    to_array: B,
206) -> Result<ArrayRef>
207where
208    V: ArrayAccessor<Item = &'a str>,
209    B: ArrayAccessor<Item = &'a str>,
210{
211    let string_array_iter = ArrayIter::new(string_array);
212    let from_array_iter = ArrayIter::new(from_array);
213    let to_array_iter = ArrayIter::new(to_array);
214
215    // Reusable buffers to avoid allocating for each row
216    let mut from_map: HashMap<&str, usize> = HashMap::new();
217    let mut from_graphemes: Vec<&str> = Vec::new();
218    let mut to_graphemes: Vec<&str> = Vec::new();
219    let mut string_graphemes: Vec<&str> = Vec::new();
220    let mut result_graphemes: Vec<&str> = Vec::new();
221
222    let result = string_array_iter
223        .zip(from_array_iter)
224        .zip(to_array_iter)
225        .map(|((string, from), to)| match (string, from, to) {
226            (Some(string), Some(from), Some(to)) => {
227                // Clear and reuse buffers
228                from_map.clear();
229                from_graphemes.clear();
230                to_graphemes.clear();
231                string_graphemes.clear();
232                result_graphemes.clear();
233
234                // Build from_map using reusable buffer
235                from_graphemes.extend(from.graphemes(true));
236                for (index, c) in from_graphemes.iter().enumerate() {
237                    // Ignore characters that already exist in from_map
238                    from_map.entry(*c).or_insert(index);
239                }
240
241                // Build to_graphemes
242                to_graphemes.extend(to.graphemes(true));
243
244                // Process string and build result
245                string_graphemes.extend(string.graphemes(true));
246                for c in &string_graphemes {
247                    match from_map.get(*c) {
248                        Some(n) => {
249                            if let Some(replacement) = to_graphemes.get(*n) {
250                                result_graphemes.push(*replacement);
251                            }
252                        }
253                        None => result_graphemes.push(*c),
254                    }
255                }
256
257                Some(result_graphemes.concat())
258            }
259            _ => None,
260        })
261        .collect::<GenericStringArray<T>>();
262
263    Ok(Arc::new(result) as ArrayRef)
264}
265
/// Sentinel value in the ASCII translate table indicating the character should
/// be deleted (the `from` character has no corresponding `to` character). Any
/// value > 127 works since valid ASCII is 0–127.
const ASCII_DELETE: u8 = 0xFF;

/// If `from` and `to` can be translated byte by byte, build a fixed-size lookup
/// table for translation.
///
/// Each entry maps an input byte to its replacement byte, or to
/// [`ASCII_DELETE`] if the character should be removed. Bytes that do not
/// occur in `from` map to themselves. When a byte occurs more than once in
/// `from`, only its first occurrence determines the mapping.
///
/// Returns `None` if either string contains non-ASCII characters, or contains
/// the CR LF pair `"\r\n"`. Although ASCII, CR LF forms a single extended
/// grapheme cluster, so a byte table would disagree with grapheme-based
/// translation. For example, with `from = "\r\nab"`, `a` is the second
/// grapheme but the third byte.
fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
    // Every byte is its own grapheme only for ASCII text without CR LF.
    let is_bytewise = |s: &str| s.is_ascii() && !s.contains("\r\n");
    if !is_bytewise(from) || !is_bytewise(to) {
        return None;
    }

    // Identity mapping to start with; indices 0..128 always fit in a u8.
    let mut table: [u8; 128] = std::array::from_fn(|i| i as u8);
    let mut seen = [false; 128];
    let to_bytes = to.as_bytes();

    for (position, from_byte) in from.bytes().enumerate() {
        // `from` is ASCII, so the index is < 128.
        let idx = usize::from(from_byte);
        if seen[idx] {
            // Later duplicates in `from` never override the first mapping.
            continue;
        }
        seen[idx] = true;
        table[idx] = to_bytes.get(position).copied().unwrap_or(ASCII_DELETE);
    }

    Some(table)
}
298
299/// Optimized translate for constant `from` and `to` arguments: uses a pre-built
300/// translation map instead of rebuilding it for every row.  When an ASCII byte
301/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII
302/// inputs fallback to using the map.
303fn translate_with_map<'a, T: OffsetSizeTrait, V>(
304    string_array: V,
305    from_map: &HashMap<&str, usize>,
306    to_graphemes: &[&str],
307    ascii_table: Option<&[u8; 128]>,
308) -> Result<ArrayRef>
309where
310    V: ArrayAccessor<Item = &'a str>,
311{
312    let mut result_graphemes: Vec<&str> = Vec::new();
313    let mut ascii_buf: Vec<u8> = Vec::new();
314
315    let result = ArrayIter::new(string_array)
316        .map(|string| {
317            string.map(|s| {
318                // Fast path: byte-level table lookup for ASCII strings
319                if let Some(table) = ascii_table
320                    && s.is_ascii()
321                {
322                    ascii_buf.clear();
323                    for &b in s.as_bytes() {
324                        let mapped = table[b as usize];
325                        if mapped != ASCII_DELETE {
326                            ascii_buf.push(mapped);
327                        }
328                    }
329                    // SAFETY: all bytes are ASCII, hence valid UTF-8.
330                    return unsafe {
331                        std::str::from_utf8_unchecked(&ascii_buf).to_owned()
332                    };
333                }
334
335                // Slow path: grapheme-based translation
336                result_graphemes.clear();
337
338                for c in s.graphemes(true) {
339                    match from_map.get(c) {
340                        Some(n) => {
341                            if let Some(replacement) = to_graphemes.get(*n) {
342                                result_graphemes.push(*replacement);
343                            }
344                        }
345                        None => result_graphemes.push(c),
346                    }
347                }
348
349                result_graphemes.concat()
350            })
351        })
352        .collect::<GenericStringArray<T>>();
353
354    Ok(Arc::new(result) as ArrayRef)
355}
356
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::unicode::translate::TranslateFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abcabc")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("de"))
            ],
            Ok(Some("dbcdbc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("é2íñ5")),
                ColumnarValue::Scalar(ScalarValue::from("éñí")),
                ColumnarValue::Scalar(ScalarValue::from("óü")),
            ],
            Ok(Some("ó2ü5")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII input with ASCII scalar from/to: exercises the
        // grapheme fallback within translate_with_map.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("café")),
                ColumnarValue::Scalar(ScalarValue::from("ae")),
                ColumnarValue::Scalar(ScalarValue::from("AE"))
            ],
            Ok(Some("cAfé")),
            &str,
            Utf8,
            StringArray
        );
        // A repeated `from` grapheme keeps its first position; graphemes with
        // no counterpart in `to` are deleted.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abc")),
                ColumnarValue::Scalar(ScalarValue::from("abb")),
                ColumnarValue::Scalar(ScalarValue::from("x"))
            ],
            Ok(Some("xc")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8View input is translated into a Utf8 result.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "12345"
                )))),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8,
            StringArray
        );
        // Non-scalar `from`: exercises the per-row `translate` kernel.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["143"]))),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8,
            StringArray
        );
        // "\r\n" is a single grapheme cluster, so a lone "\n" in `from` does
        // not split it.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("a\r\nb")),
                ColumnarValue::Array(Arc::new(StringArray::from(vec!["\n"]))),
                ColumnarValue::Scalar(ScalarValue::from("X"))
            ],
            Ok(Some("a\r\nb")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}