Skip to main content

datafusion_spark/function/math/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::str::from_utf8_unchecked;
19use std::sync::Arc;
20
21use arrow::array::{Array, ArrayRef, StringBuilder};
22use arrow::datatypes::DataType;
23use arrow::{
24    array::{as_dictionary_array, as_largestring_array, as_string_array},
25    datatypes::Int32Type,
26};
27use datafusion_common::cast::as_large_binary_array;
28use datafusion_common::cast::as_string_view_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{
32    DataFusionError,
33    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
34    exec_err,
35};
36use datafusion_expr::{
37    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
38    TypeSignatureClass, Volatility,
39};
/// Spark-compatible `hex` scalar function.
///
/// The accepted argument types are declared by the signature built in
/// `SparkHex::new`; evaluation is delegated to `spark_hex`.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkHex {
    /// Argument signature handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
46
47impl Default for SparkHex {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkHex {
54    pub fn new() -> Self {
55        let int64 = Coercion::new_implicit(
56            TypeSignatureClass::Native(logical_int64()),
57            vec![TypeSignatureClass::Numeric],
58            NativeType::Int64,
59        );
60
61        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
62
63        let binary = Coercion::new_exact(TypeSignatureClass::Binary);
64
65        let variants = vec![
66            // accepts numeric types
67            TypeSignature::Coercible(vec![int64]),
68            // accepts string types (Utf8, Utf8View, LargeUtf8)
69            TypeSignature::Coercible(vec![string]),
70            // accepts binary types (Binary, FixedSizeBinary, LargeBinary)
71            TypeSignature::Coercible(vec![binary]),
72        ];
73
74        Self {
75            signature: Signature::one_of(variants, Volatility::Immutable),
76            aliases: vec![],
77        }
78    }
79}
80
81impl ScalarUDFImpl for SparkHex {
82    fn name(&self) -> &str {
83        "hex"
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
91        Ok(match &arg_types[0] {
92            DataType::Dictionary(key_type, _) => {
93                DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
94            }
95            _ => DataType::Utf8,
96        })
97    }
98
99    fn invoke_with_args(
100        &self,
101        args: ScalarFunctionArgs,
102    ) -> datafusion_common::Result<ColumnarValue> {
103        spark_hex(&args.args)
104    }
105
106    fn aliases(&self) -> &[String] {
107        &self.aliases
108    }
109}
110
/// Upper-case hex digit for each nibble value (`0x0..=0xF`).
const HEX_CHARS_UPPER_NIBBLES: &[u8; 16] = b"0123456789ABCDEF";
/// Lower-case hex digit for each nibble value (`0x0..=0xF`).
const HEX_CHARS_LOWER_NIBBLES: &[u8; 16] = b"0123456789abcdef";

/// Hex encoding lookup table (upper case) for fast byte-to-hex conversion.
///
/// Each entry maps a full byte to its two-character hex encoding so the
/// hot loop becomes one load + one two-byte extend per input byte instead
/// of two nibble lookups and two pushes.
const HEX_LOOKUP_UPPER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_UPPER_NIBBLES);
/// Lower-case counterpart of `HEX_LOOKUP_UPPER`, built the same way from
/// `HEX_CHARS_LOWER_NIBBLES`.
const HEX_LOOKUP_LOWER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_LOWER_NIBBLES);
121
/// Builds a 256-entry table mapping every byte value to its two hex digits.
///
/// `nibbles` supplies the digit for each 4-bit value; entry `b` of the result
/// is `[nibbles[b >> 4], nibbles[b & 0xF]]`, i.e. high nibble first.
const fn build_hex_lookup(nibbles: &[u8; 16]) -> [[u8; 2]; 256] {
    let mut pairs = [[0u8; 2]; 256];
    // `for` loops are not available in `const fn`, hence the manual counter.
    let mut byte: usize = 0;
    while byte < 256 {
        // `byte` is below 256, so `byte >> 4` is already within 0..16.
        let high = nibbles[byte >> 4];
        let low = nibbles[byte & 0x0F];
        pairs[byte] = [high, low];
        byte += 1;
    }
    pairs
}
132
133#[inline]
134fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] {
135    if num == 0 {
136        return b"0";
137    }
138
139    // Walk the value two nibbles (one full byte) at a time. The buffer is
140    // filled from the right so the high-order nibbles end up first; the
141    // returned slice trims leading zeros automatically.
142    let mut n = num as u64;
143    let mut i = 16;
144    while n >= 0x10 {
145        i -= 2;
146        let pair = HEX_LOOKUP_UPPER[(n & 0xFF) as usize];
147        buffer[i] = pair[0];
148        buffer[i + 1] = pair[1];
149        n >>= 8;
150    }
151    if n > 0 {
152        // Single remaining high nibble (value 0x1..=0xF).
153        i -= 1;
154        buffer[i] = HEX_CHARS_UPPER_NIBBLES[n as usize];
155    }
156    &buffer[i..]
157}
158
159/// Generic hex encoding for byte array types
160fn hex_encode_bytes<'a, I, T>(
161    iter: I,
162    lowercase: bool,
163    len: usize,
164) -> Result<ArrayRef, DataFusionError>
165where
166    I: Iterator<Item = Option<T>>,
167    T: AsRef<[u8]> + 'a,
168{
169    let mut builder = StringBuilder::with_capacity(len, len * 64);
170    let mut buffer = Vec::with_capacity(64);
171    let lookup = if lowercase {
172        &HEX_LOOKUP_LOWER
173    } else {
174        &HEX_LOOKUP_UPPER
175    };
176
177    for v in iter {
178        if let Some(b) = v {
179            let bytes = b.as_ref();
180            buffer.clear();
181            buffer.reserve(bytes.len() * 2);
182            for &byte in bytes {
183                buffer.extend_from_slice(&lookup[byte as usize]);
184            }
185            // SAFETY: buffer contains only ASCII hex digits, which are valid UTF-8.
186            unsafe {
187                builder.append_value(from_utf8_unchecked(&buffer));
188            }
189        } else {
190            builder.append_null();
191        }
192    }
193
194    Ok(Arc::new(builder.finish()))
195}
196
197/// Generic hex encoding for int64 type
198fn hex_encode_int64(
199    iter: impl Iterator<Item = Option<i64>>,
200    len: usize,
201) -> Result<ArrayRef, DataFusionError> {
202    let mut builder = StringBuilder::with_capacity(len, len * 16);
203
204    for v in iter {
205        if let Some(num) = v {
206            let mut temp = [0u8; 16];
207            let slice = hex_int64(num, &mut temp);
208            // SAFETY: slice contains only ASCII hex digests, which are valid UTF-8
209            unsafe {
210                builder.append_value(from_utf8_unchecked(slice));
211            }
212        } else {
213            builder.append_null();
214        }
215    }
216
217    Ok(Arc::new(builder.finish()))
218}
219
220/// Spark-compatible `hex` function
221pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
222    compute_hex(args, false)
223}
224
225/// Spark-compatible `sha2` function
226pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
227    compute_hex(args, true)
228}
229
230pub fn compute_hex(
231    args: &[ColumnarValue],
232    lowercase: bool,
233) -> Result<ColumnarValue, DataFusionError> {
234    let input = match take_function_args("hex", args)? {
235        [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
236        [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
237    };
238
239    match &input {
240        ColumnarValue::Array(array) => match array.data_type() {
241            DataType::Int64 => {
242                let array = as_int64_array(array)?;
243                Ok(ColumnarValue::Array(hex_encode_int64(
244                    array.iter(),
245                    array.len(),
246                )?))
247            }
248            DataType::Utf8 => {
249                let array = as_string_array(array);
250                Ok(ColumnarValue::Array(hex_encode_bytes(
251                    array.iter(),
252                    lowercase,
253                    array.len(),
254                )?))
255            }
256            DataType::Utf8View => {
257                let array = as_string_view_array(array)?;
258                Ok(ColumnarValue::Array(hex_encode_bytes(
259                    array.iter(),
260                    lowercase,
261                    array.len(),
262                )?))
263            }
264            DataType::LargeUtf8 => {
265                let array = as_largestring_array(array);
266                Ok(ColumnarValue::Array(hex_encode_bytes(
267                    array.iter(),
268                    lowercase,
269                    array.len(),
270                )?))
271            }
272            DataType::Binary => {
273                let array = as_binary_array(array)?;
274                Ok(ColumnarValue::Array(hex_encode_bytes(
275                    array.iter(),
276                    lowercase,
277                    array.len(),
278                )?))
279            }
280            DataType::LargeBinary => {
281                let array = as_large_binary_array(array)?;
282                Ok(ColumnarValue::Array(hex_encode_bytes(
283                    array.iter(),
284                    lowercase,
285                    array.len(),
286                )?))
287            }
288            DataType::FixedSizeBinary(_) => {
289                let array = as_fixed_size_binary_array(array)?;
290                Ok(ColumnarValue::Array(hex_encode_bytes(
291                    array.iter(),
292                    lowercase,
293                    array.len(),
294                )?))
295            }
296            DataType::Dictionary(key_type, _) => {
297                if **key_type != DataType::Int32 {
298                    return exec_err!(
299                        "hex only supports Int32 dictionary keys, get: {}",
300                        key_type
301                    );
302                }
303
304                let dict = as_dictionary_array::<Int32Type>(&array);
305                let dict_values = dict.values();
306
307                let encoded_values = match dict_values.data_type() {
308                    DataType::Int64 => {
309                        let arr = as_int64_array(dict_values)?;
310                        hex_encode_int64(arr.iter(), arr.len())?
311                    }
312                    DataType::Utf8 => {
313                        let arr = as_string_array(dict_values);
314                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
315                    }
316                    DataType::LargeUtf8 => {
317                        let arr = as_largestring_array(dict_values);
318                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
319                    }
320                    DataType::Utf8View => {
321                        let arr = as_string_view_array(dict_values)?;
322                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
323                    }
324                    DataType::Binary => {
325                        let arr = as_binary_array(dict_values)?;
326                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
327                    }
328                    DataType::LargeBinary => {
329                        let arr = as_large_binary_array(dict_values)?;
330                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
331                    }
332                    DataType::FixedSizeBinary(_) => {
333                        let arr = as_fixed_size_binary_array(dict_values)?;
334                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
335                    }
336                    _ => {
337                        return exec_err!(
338                            "hex got an unexpected argument type: {}",
339                            dict_values.data_type()
340                        );
341                    }
342                };
343
344                let new_dict = dict.with_values(encoded_values);
345                Ok(ColumnarValue::Array(Arc::new(new_dict)))
346            }
347            _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
348        },
349        _ => exec_err!("native hex does not support scalar values at this time"),
350    }
351}
352
/// Unit tests for the hex encoders: the integer formatter, the byte lookup
/// tables, and the array/dictionary entry points.
#[cfg(test)]
mod test {
    use std::str::from_utf8;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, BinaryArray, DictionaryArray, Int32Array, Int64Array, StringArray,
    };
    use arrow::{
        array::{
            BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
            as_string_array,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_common::cast::as_dictionary_array;
    use datafusion_expr::ColumnarValue;

    /// Unwraps the array out of a result, failing the test on a scalar.
    fn expect_array(value: ColumnarValue) -> ArrayRef {
        match value {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("hi");
        input_builder.append_value("bye");
        input_builder.append_null();
        input_builder.append_value("rust");
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("6869");
        expected_builder.append_value("627965");
        expected_builder.append_null();
        expected_builder.append_value("72757374");
        let expected = expected_builder.finish();

        let columnar_value = ColumnarValue::Array(Arc::new(input));
        let result = expect_array(super::spark_hex(&[columnar_value]).unwrap());

        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        input_builder.append_value(1);
        input_builder.append_value(2);
        input_builder.append_null();
        input_builder.append_value(3);
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("1");
        expected_builder.append_value("2");
        expected_builder.append_null();
        expected_builder.append_value("3");
        let expected = expected_builder.finish();

        let columnar_value = ColumnarValue::Array(Arc::new(input));
        let result = expect_array(super::spark_hex(&[columnar_value]).unwrap());

        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("1");
        input_builder.append_value("j");
        input_builder.append_null();
        input_builder.append_value("3");
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("31");
        expected_builder.append_value("6A");
        expected_builder.append_null();
        expected_builder.append_value("33");
        let expected = expected_builder.finish();

        let columnar_value = ColumnarValue::Array(Arc::new(input));
        let result = expect_array(super::spark_hex(&[columnar_value]).unwrap());

        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_hex_int64() {
        let test_cases = vec![
            (0_i64, "0"),
            (1, "1"),
            (15, "F"),
            (16, "10"),
            (255, "FF"),
            (256, "100"),
            // Low byte below 0x10 must keep its leading zero inside the number.
            (0x10F, "10F"),
            (0xF00, "F00"),
            (0x1000, "1000"),
            (1234, "4D2"),
            // Odd digit count spanning almost the whole buffer.
            (0x0102_0304_0506_0708, "102030405060708"),
            (i64::MAX, "7FFFFFFFFFFFFFFF"),
            (i64::MIN, "8000000000000000"),
            (-1, "FFFFFFFFFFFFFFFF"),
            (-256, "FFFFFFFFFFFFFF00"),
        ];

        for (num, expected) in test_cases {
            let mut cache = [0u8; 16];
            let slice = super::hex_int64(num, &mut cache);

            // Checked conversion: a non-ASCII byte would fail the test rather
            // than being undefined behaviour.
            let result = from_utf8(slice).expect("hex_int64 must emit valid UTF-8");
            assert_eq!(expected, result, "hex_int64({num}) mismatch");
        }
    }

    #[test]
    fn test_hex_lookup_table_covers_all_bytes() {
        // Cross-check the precomputed table against an independent encoder
        // for every possible byte value and both casings.
        for byte in 0u8..=255 {
            let upper = format!("{byte:02X}");
            let lower = format!("{byte:02x}");
            let upper_pair = super::HEX_LOOKUP_UPPER[byte as usize];
            let lower_pair = super::HEX_LOOKUP_LOWER[byte as usize];
            assert_eq!(
                upper.as_bytes(),
                &upper_pair,
                "upper encoding mismatch for byte 0x{byte:02X}"
            );
            assert_eq!(
                lower.as_bytes(),
                &lower_pair,
                "lower encoding mismatch for byte 0x{byte:02X}"
            );
        }
    }

    #[test]
    fn test_spark_hex_binary_round_trip_all_bytes() {
        use std::fmt::Write;

        // Single-row binary input containing every byte value, encoded in
        // a single column. Catches per-byte regressions in the bytes path.
        let payload: Vec<u8> = (0u8..=255).collect();
        let bin_array = BinaryArray::from(vec![Some(payload.as_slice())]);

        let array = expect_array(
            super::spark_hex(&[ColumnarValue::Array(Arc::new(bin_array))]).unwrap(),
        );
        let strings = as_string_array(&array);
        let mut expected = String::with_capacity(512);
        for byte in 0u8..=255 {
            write!(expected, "{byte:02X}").unwrap();
        }
        assert_eq!(strings.value(0), expected);
    }

    #[test]
    fn test_spark_sha2_hex_is_lowercase() {
        // The lower-case entry point must use `a`-`f` and preserve nulls.
        let payload: &[u8] = &[0xAB, 0xCD, 0x0F];
        let bin_array = BinaryArray::from(vec![Some(payload), None]);

        let array = expect_array(
            super::spark_sha2_hex(&[ColumnarValue::Array(Arc::new(bin_array))]).unwrap(),
        );
        let strings = as_string_array(&array);
        let expected = StringArray::from(vec![Some("abcd0f"), None]);

        assert_eq!(strings, &expected);
    }

    #[test]
    fn test_spark_hex_int64() {
        let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
        let columnar_value = ColumnarValue::Array(Arc::new(int_array));

        let result = expect_array(super::spark_hex(&[columnar_value]).unwrap());

        let string_array = as_string_array(&result);
        let expected_array = StringArray::from(vec![
            Some("1".to_string()),
            Some("2".to_string()),
            None,
            Some("3".to_string()),
        ]);

        assert_eq!(string_array, &expected_array);
    }

    #[test]
    fn test_dict_values_null() {
        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let vals = Int64Array::from(vec![Some(32), None]);
        // [32, null, null]
        let dict = DictionaryArray::new(keys, Arc::new(vals));

        let columnar_value = ColumnarValue::Array(Arc::new(dict));
        let result = expect_array(super::spark_hex(&[columnar_value]).unwrap());

        let result = as_dictionary_array(&result).unwrap();

        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let vals = StringArray::from(vec![Some("20"), None]);
        let expected = DictionaryArray::new(keys, Arc::new(vals));

        assert_eq!(&expected, result);
    }
}