Skip to main content

datafusion_functions/unicode/
translate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    ArrayAccessor, ArrayIter, ArrayRef, AsArray, LargeStringBuilder, StringBuilder,
20    StringLikeArrayBuilder, StringViewBuilder,
21};
22use arrow::datatypes::DataType;
23use datafusion_common::HashMap;
24
25use crate::utils::make_scalar_function;
26use datafusion_common::{Result, exec_err};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    Volatility,
31};
32use datafusion_macros::user_doc;
33
/// Scalar UDF implementing `translate(str, from, to)`: character-wise
/// substitution of `str` driven by the `from` -> `to` character mapping.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Performs character-wise substitution based on a mapping.",
    syntax_example = "translate(str, from, to)",
    sql_example = r#"```sql
> select translate('twice', 'wic', 'her');
+--------------------------------------------------+
| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) |
+--------------------------------------------------+
| there                                            |
+--------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "from", description = "The characters to be replaced."),
    argument(
        name = "to",
        description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TranslateFunc {
    // Accepted argument types; constructed in `TranslateFunc::new`.
    signature: Signature,
}
57
58impl Default for TranslateFunc {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl TranslateFunc {
65    pub fn new() -> Self {
66        use DataType::*;
67        Self {
68            signature: Signature::one_of(
69                vec![
70                    Exact(vec![Utf8View, Utf8, Utf8]),
71                    Exact(vec![Utf8, Utf8, Utf8]),
72                    Exact(vec![LargeUtf8, Utf8, Utf8]),
73                ],
74                Volatility::Immutable,
75            ),
76        }
77    }
78}
79
80impl ScalarUDFImpl for TranslateFunc {
81    fn name(&self) -> &str {
82        "translate"
83    }
84
85    fn signature(&self) -> &Signature {
86        &self.signature
87    }
88
89    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
90        Ok(arg_types[0].clone())
91    }
92
93    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
94        // When from and to are scalars, pre-build the translation map once
95        if let (Some(from_str), Some(to_str)) = (
96            try_as_scalar_str(&args.args[1]),
97            try_as_scalar_str(&args.args[2]),
98        ) {
99            let to_chars: Vec<char> = to_str.chars().collect();
100
101            let mut from_map: HashMap<char, usize> = HashMap::new();
102            for (index, c) in from_str.chars().enumerate() {
103                from_map.entry(c).or_insert(index);
104            }
105
106            let ascii_table = build_ascii_translate_table(from_str, to_str);
107
108            let string_array = args.args[0].to_array_of_size(args.number_rows)?;
109            let len = string_array.len();
110
111            let result = match string_array.data_type() {
112                DataType::Utf8View => {
113                    let arr = string_array.as_string_view();
114                    let builder = StringViewBuilder::with_capacity(len);
115                    translate_with_map(
116                        arr,
117                        &from_map,
118                        &to_chars,
119                        ascii_table.as_ref(),
120                        builder,
121                    )
122                }
123                DataType::Utf8 => {
124                    let arr = string_array.as_string::<i32>();
125                    let builder =
126                        StringBuilder::with_capacity(len, arr.value_data().len());
127                    translate_with_map(
128                        arr,
129                        &from_map,
130                        &to_chars,
131                        ascii_table.as_ref(),
132                        builder,
133                    )
134                }
135                DataType::LargeUtf8 => {
136                    let arr = string_array.as_string::<i64>();
137                    let builder =
138                        LargeStringBuilder::with_capacity(len, arr.value_data().len());
139                    translate_with_map(
140                        arr,
141                        &from_map,
142                        &to_chars,
143                        ascii_table.as_ref(),
144                        builder,
145                    )
146                }
147                other => {
148                    return exec_err!(
149                        "Unsupported data type {other:?} for function translate"
150                    );
151                }
152            }?;
153
154            return Ok(ColumnarValue::Array(result));
155        }
156
157        make_scalar_function(invoke_translate, vec![])(&args.args)
158    }
159
160    fn documentation(&self) -> Option<&Documentation> {
161        self.doc()
162    }
163}
164
165use super::common::try_as_scalar_str;
166
167fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
168    let len = args[0].len();
169    match args[0].data_type() {
170        DataType::Utf8View => {
171            let string_array = args[0].as_string_view();
172            let from_array = args[1].as_string::<i32>();
173            let to_array = args[2].as_string::<i32>();
174            let builder = StringViewBuilder::with_capacity(len);
175            translate(string_array, from_array, to_array, builder)
176        }
177        DataType::Utf8 => {
178            let string_array = args[0].as_string::<i32>();
179            let from_array = args[1].as_string::<i32>();
180            let to_array = args[2].as_string::<i32>();
181            let builder =
182                StringBuilder::with_capacity(len, string_array.value_data().len());
183            translate(string_array, from_array, to_array, builder)
184        }
185        DataType::LargeUtf8 => {
186            let string_array = args[0].as_string::<i64>();
187            let from_array = args[1].as_string::<i32>();
188            let to_array = args[2].as_string::<i32>();
189            let builder =
190                LargeStringBuilder::with_capacity(len, string_array.value_data().len());
191            translate(string_array, from_array, to_array, builder)
192        }
193        other => {
194            exec_err!("Unsupported data type {other:?} for function translate")
195        }
196    }
197}
198
199/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
200/// translate('12345', '143', 'ax') = 'a2x5'
201fn translate<'a, V, B, O>(
202    string_array: V,
203    from_array: B,
204    to_array: B,
205    mut builder: O,
206) -> Result<ArrayRef>
207where
208    V: ArrayAccessor<Item = &'a str>,
209    B: ArrayAccessor<Item = &'a str>,
210    O: StringLikeArrayBuilder,
211{
212    let string_array_iter = ArrayIter::new(string_array);
213    let from_array_iter = ArrayIter::new(from_array);
214    let to_array_iter = ArrayIter::new(to_array);
215
216    let mut from_map: HashMap<char, usize> = HashMap::new();
217    let mut to_chars: Vec<char> = Vec::new();
218    let mut result_buf = String::new();
219
220    for ((string, from), to) in string_array_iter.zip(from_array_iter).zip(to_array_iter)
221    {
222        match (string, from, to) {
223            (Some(string), Some(from), Some(to)) => {
224                from_map.clear();
225                to_chars.clear();
226                result_buf.clear();
227
228                for (index, c) in from.chars().enumerate() {
229                    from_map.entry(c).or_insert(index);
230                }
231
232                to_chars.extend(to.chars());
233
234                translate_char_by_char(string, &from_map, &to_chars, &mut result_buf);
235
236                builder.append_value(&result_buf);
237            }
238            _ => builder.append_null(),
239        }
240    }
241
242    Ok(builder.finish())
243}
244
/// Appends the translation of `input` to `buf` (existing contents are kept).
///
/// Characters absent from `from_map` are copied unchanged. A character that
/// maps to index `n` becomes `to_chars[n]`, or is dropped entirely when
/// `to_chars` has no element at that index.
#[inline]
fn translate_char_by_char(
    input: &str,
    from_map: &HashMap<char, usize>,
    to_chars: &[char],
    buf: &mut String,
) {
    buf.extend(input.chars().filter_map(|ch| match from_map.get(&ch) {
        Some(&index) => to_chars.get(index).copied(),
        None => Some(ch),
    }));
}
265
/// Marker stored in the ASCII translate table for input bytes that must be
/// removed from the output, because their `from` character has no
/// counterpart in `to`. Any value above 127 is unambiguous, since genuine
/// table entries are always ASCII (0–127).
const ASCII_DELETE: u8 = u8::MAX;

/// Builds a 128-entry byte lookup table when both `from` and `to` are pure
/// ASCII, or returns `None` so the caller falls back to the char-based map.
///
/// Entry `b` is what ASCII input byte `b` translates to:
/// - `b` itself, if `b` does not occur in `from`;
/// - the byte of `to` at the position of `b`'s first occurrence in `from`;
/// - [`ASCII_DELETE`], if `to` is too short to have such a byte.
fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
    if !(from.is_ascii() && to.is_ascii()) {
        return None;
    }

    // Start from the identity mapping.
    let mut table: [u8; 128] = std::array::from_fn(|byte| byte as u8);
    let mut mapped = [false; 128];
    // Advanced once per `from` byte (even for skipped repeats) so it stays
    // aligned with the position in `from`; `Bytes` keeps yielding `None`
    // once exhausted.
    let mut to_bytes = to.bytes();

    for from_byte in from.bytes() {
        let replacement = to_bytes.next().unwrap_or(ASCII_DELETE);
        let slot = usize::from(from_byte);
        if !mapped[slot] {
            mapped[slot] = true;
            table[slot] = replacement;
        }
    }

    Some(table)
}
298
299/// Optimized translate for constant `from` and `to` arguments: uses a pre-built
300/// translation map instead of rebuilding it for every row.  When an ASCII byte
301/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII
302/// inputs fall back to the char-based map.
303fn translate_with_map<'a, V, O>(
304    string_array: V,
305    from_map: &HashMap<char, usize>,
306    to_chars: &[char],
307    ascii_table: Option<&[u8; 128]>,
308    mut builder: O,
309) -> Result<ArrayRef>
310where
311    V: ArrayAccessor<Item = &'a str>,
312    O: StringLikeArrayBuilder,
313{
314    let mut result_buf = String::new();
315    let mut ascii_buf: Vec<u8> = Vec::new();
316
317    for string in ArrayIter::new(string_array) {
318        match string {
319            Some(s) => {
320                // Fast path: byte-level table lookup for ASCII strings
321                if let Some(table) = ascii_table
322                    && s.is_ascii()
323                {
324                    ascii_buf.clear();
325                    for &b in s.as_bytes() {
326                        let mapped = table[b as usize];
327                        if mapped != ASCII_DELETE {
328                            ascii_buf.push(mapped);
329                        }
330                    }
331                    // SAFETY: all bytes are ASCII, hence valid UTF-8.
332                    builder.append_value(unsafe {
333                        std::str::from_utf8_unchecked(&ascii_buf)
334                    });
335                } else {
336                    result_buf.clear();
337                    translate_char_by_char(s, from_map, to_chars, &mut result_buf);
338                    builder.append_value(&result_buf);
339                }
340            }
341            None => builder.append_null(),
342        }
343    }
344
345    Ok(builder.finish())
346}
347
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, AsArray, LargeStringArray, StringArray, StringViewArray,
    };
    use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use super::invoke_translate;
    use crate::unicode::translate::TranslateFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("12345")),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None))
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("abcabc")),
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("de"))
            ],
            Ok(Some("dbcdbc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("é2íñ5")),
                ColumnarValue::Scalar(ScalarValue::from("éñí")),
                ColumnarValue::Scalar(ScalarValue::from("óü")),
            ],
            Ok(Some("ó2ü5")),
            &str,
            Utf8,
            StringArray
        );
        // Non-ASCII input with ASCII scalar from/to: exercises the
        // char-based fallback within translate_with_map.
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("café")),
                ColumnarValue::Scalar(ScalarValue::from("ae")),
                ColumnarValue::Scalar(ScalarValue::from("AE"))
            ],
            Ok(Some("cAfé")),
            &str,
            Utf8,
            StringArray
        );
        // Utf8View input should produce Utf8View output
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            Utf8View,
            StringViewArray
        );
        // Null Utf8View input
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Non-ASCII Utf8View input
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))),
                ColumnarValue::Scalar(ScalarValue::from("éñí")),
                ColumnarValue::Scalar(ScalarValue::from("óü"))
            ],
            Ok(Some("ó2ü5")),
            &str,
            Utf8View,
            StringViewArray
        );
        // LargeUtf8 input should produce LargeUtf8 output
        test_function!(
            TranslateFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                    "12345".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::from("143")),
                ColumnarValue::Scalar(ScalarValue::from("ax"))
            ],
            Ok(Some("a2x5")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        Ok(())
    }

    /// Array-valued `from` / `to` take the per-row path in `invoke_translate`,
    /// which the scalar-argument cases above only reach for null rows.
    #[test]
    fn test_translate_array_arguments() -> Result<()> {
        let strings: ArrayRef = Arc::new(StringArray::from(vec![
            Some("12345"),
            Some("abcabc"),
            None,
            Some("xyz"),
        ]));
        let from: ArrayRef = Arc::new(StringArray::from(vec![
            Some("143"),
            Some("aa"),
            Some("a"),
            None,
        ]));
        let to: ArrayRef = Arc::new(StringArray::from(vec![
            Some("ax"),
            Some("de"),
            Some("b"),
            Some("q"),
        ]));

        let result = invoke_translate(&[strings, from, to])?;
        let result = result.as_string::<i32>();
        assert_eq!(result.len(), 4);
        assert_eq!(result.value(0), "a2x5");
        assert_eq!(result.value(1), "dbcdbc");
        assert!(result.is_null(2));
        assert!(result.is_null(3));
        Ok(())
    }
}