Skip to main content

datafusion_spark/function/math/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use std::any::Any;
use std::str::from_utf8_unchecked;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StringBuilder};
use arrow::datatypes::DataType;
use arrow::{
    array::{as_dictionary_array, as_largestring_array, as_string_array},
    datatypes::{
        ArrowDictionaryKeyType, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
        UInt16Type, UInt32Type, UInt64Type,
    },
};
use datafusion_common::cast::as_large_binary_array;
use datafusion_common::cast::as_string_view_array;
use datafusion_common::types::{NativeType, logical_int64, logical_string};
use datafusion_common::utils::take_function_args;
use datafusion_common::{
    DataFusionError,
    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
    exec_err,
};
use datafusion_expr::{
    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
    TypeSignatureClass, Volatility,
};
41/// <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
42#[derive(Debug, PartialEq, Eq, Hash)]
43pub struct SparkHex {
44    signature: Signature,
45    aliases: Vec<String>,
46}
47
48impl Default for SparkHex {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl SparkHex {
55    pub fn new() -> Self {
56        let int64 = Coercion::new_implicit(
57            TypeSignatureClass::Native(logical_int64()),
58            vec![TypeSignatureClass::Numeric],
59            NativeType::Int64,
60        );
61
62        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
63
64        let binary = Coercion::new_exact(TypeSignatureClass::Binary);
65
66        let variants = vec![
67            // accepts numeric types
68            TypeSignature::Coercible(vec![int64]),
69            // accepts string types (Utf8, Utf8View, LargeUtf8)
70            TypeSignature::Coercible(vec![string]),
71            // accepts binary types (Binary, FixedSizeBinary, LargeBinary)
72            TypeSignature::Coercible(vec![binary]),
73        ];
74
75        Self {
76            signature: Signature::one_of(variants, Volatility::Immutable),
77            aliases: vec![],
78        }
79    }
80}
81
82impl ScalarUDFImpl for SparkHex {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn name(&self) -> &str {
88        "hex"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
96        Ok(match &arg_types[0] {
97            DataType::Dictionary(key_type, _) => {
98                DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
99            }
100            _ => DataType::Utf8,
101        })
102    }
103
104    fn invoke_with_args(
105        &self,
106        args: ScalarFunctionArgs,
107    ) -> datafusion_common::Result<ColumnarValue> {
108        spark_hex(&args.args)
109    }
110
111    fn aliases(&self) -> &[String] {
112        &self.aliases
113    }
114}
115
/// Uppercase hex digits; index `i` holds the ASCII digit for nibble value `i`.
const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF";
/// Lowercase hex digits; index `i` holds the ASCII digit for nibble value `i`.
const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
119
120#[inline]
121fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] {
122    if num == 0 {
123        return b"0";
124    }
125
126    let mut n = num as u64;
127    let mut i = 16;
128    while n != 0 {
129        i -= 1;
130        buffer[i] = HEX_CHARS_UPPER[(n & 0xF) as usize];
131        n >>= 4;
132    }
133    &buffer[i..]
134}
135
136/// Generic hex encoding for byte array types
137fn hex_encode_bytes<'a, I, T>(
138    iter: I,
139    lowercase: bool,
140    len: usize,
141) -> Result<ArrayRef, DataFusionError>
142where
143    I: Iterator<Item = Option<T>>,
144    T: AsRef<[u8]> + 'a,
145{
146    let mut builder = StringBuilder::with_capacity(len, len * 64);
147    let mut buffer = Vec::with_capacity(64);
148    let hex_chars = if lowercase {
149        HEX_CHARS_LOWER
150    } else {
151        HEX_CHARS_UPPER
152    };
153
154    for v in iter {
155        if let Some(b) = v {
156            buffer.clear();
157            let bytes = b.as_ref();
158            for &byte in bytes {
159                buffer.push(hex_chars[(byte >> 4) as usize]);
160                buffer.push(hex_chars[(byte & 0x0f) as usize]);
161            }
162            // SAFETY: buffer contains only ASCII hex digests, which are valid UTF-8
163            unsafe {
164                builder.append_value(from_utf8_unchecked(&buffer));
165            }
166        } else {
167            builder.append_null();
168        }
169    }
170
171    Ok(Arc::new(builder.finish()))
172}
173
174/// Generic hex encoding for int64 type
175fn hex_encode_int64(
176    iter: impl Iterator<Item = Option<i64>>,
177    len: usize,
178) -> Result<ArrayRef, DataFusionError> {
179    let mut builder = StringBuilder::with_capacity(len, len * 16);
180
181    for v in iter {
182        if let Some(num) = v {
183            let mut temp = [0u8; 16];
184            let slice = hex_int64(num, &mut temp);
185            // SAFETY: slice contains only ASCII hex digests, which are valid UTF-8
186            unsafe {
187                builder.append_value(from_utf8_unchecked(slice));
188            }
189        } else {
190            builder.append_null();
191        }
192    }
193
194    Ok(Arc::new(builder.finish()))
195}
196
197/// Spark-compatible `hex` function
198pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
199    compute_hex(args, false)
200}
201
202/// Spark-compatible `sha2` function
203pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
204    compute_hex(args, true)
205}
206
207pub fn compute_hex(
208    args: &[ColumnarValue],
209    lowercase: bool,
210) -> Result<ColumnarValue, DataFusionError> {
211    let input = match take_function_args("hex", args)? {
212        [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
213        [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
214    };
215
216    match &input {
217        ColumnarValue::Array(array) => match array.data_type() {
218            DataType::Int64 => {
219                let array = as_int64_array(array)?;
220                Ok(ColumnarValue::Array(hex_encode_int64(
221                    array.iter(),
222                    array.len(),
223                )?))
224            }
225            DataType::Utf8 => {
226                let array = as_string_array(array);
227                Ok(ColumnarValue::Array(hex_encode_bytes(
228                    array.iter(),
229                    lowercase,
230                    array.len(),
231                )?))
232            }
233            DataType::Utf8View => {
234                let array = as_string_view_array(array)?;
235                Ok(ColumnarValue::Array(hex_encode_bytes(
236                    array.iter(),
237                    lowercase,
238                    array.len(),
239                )?))
240            }
241            DataType::LargeUtf8 => {
242                let array = as_largestring_array(array);
243                Ok(ColumnarValue::Array(hex_encode_bytes(
244                    array.iter(),
245                    lowercase,
246                    array.len(),
247                )?))
248            }
249            DataType::Binary => {
250                let array = as_binary_array(array)?;
251                Ok(ColumnarValue::Array(hex_encode_bytes(
252                    array.iter(),
253                    lowercase,
254                    array.len(),
255                )?))
256            }
257            DataType::LargeBinary => {
258                let array = as_large_binary_array(array)?;
259                Ok(ColumnarValue::Array(hex_encode_bytes(
260                    array.iter(),
261                    lowercase,
262                    array.len(),
263                )?))
264            }
265            DataType::FixedSizeBinary(_) => {
266                let array = as_fixed_size_binary_array(array)?;
267                Ok(ColumnarValue::Array(hex_encode_bytes(
268                    array.iter(),
269                    lowercase,
270                    array.len(),
271                )?))
272            }
273            DataType::Dictionary(key_type, _) => {
274                if **key_type != DataType::Int32 {
275                    return exec_err!(
276                        "hex only supports Int32 dictionary keys, get: {}",
277                        key_type
278                    );
279                }
280
281                let dict = as_dictionary_array::<Int32Type>(&array);
282                let dict_values = dict.values();
283
284                let encoded_values = match dict_values.data_type() {
285                    DataType::Int64 => {
286                        let arr = as_int64_array(dict_values)?;
287                        hex_encode_int64(arr.iter(), arr.len())?
288                    }
289                    DataType::Utf8 => {
290                        let arr = as_string_array(dict_values);
291                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
292                    }
293                    DataType::LargeUtf8 => {
294                        let arr = as_largestring_array(dict_values);
295                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
296                    }
297                    DataType::Utf8View => {
298                        let arr = as_string_view_array(dict_values)?;
299                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
300                    }
301                    DataType::Binary => {
302                        let arr = as_binary_array(dict_values)?;
303                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
304                    }
305                    DataType::LargeBinary => {
306                        let arr = as_large_binary_array(dict_values)?;
307                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
308                    }
309                    DataType::FixedSizeBinary(_) => {
310                        let arr = as_fixed_size_binary_array(dict_values)?;
311                        hex_encode_bytes(arr.iter(), lowercase, arr.len())?
312                    }
313                    _ => {
314                        return exec_err!(
315                            "hex got an unexpected argument type: {}",
316                            dict_values.data_type()
317                        );
318                    }
319                };
320
321                let new_dict = dict.with_values(encoded_values);
322                Ok(ColumnarValue::Array(Arc::new(new_dict)))
323            }
324            _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
325        },
326        _ => exec_err!("native hex does not support scalar values at this time"),
327    }
328}
329
#[cfg(test)]
mod test {
    use std::str::from_utf8_unchecked;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, BinaryArray, DictionaryArray, Int32Array, Int64Array, StringArray,
    };
    use arrow::{
        array::{
            BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
            as_string_array,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_dictionary_array;
    use datafusion_expr::ColumnarValue;

    /// Extracts the array from `value`, failing the test if it is a scalar.
    fn unwrap_array(value: ColumnarValue) -> ArrayRef {
        match value {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        }
    }

    /// Runs `spark_hex` on a single array argument and returns the result array.
    fn spark_hex_array(input: ArrayRef) -> ArrayRef {
        unwrap_array(super::spark_hex(&[ColumnarValue::Array(input)]).unwrap())
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("hi");
        input_builder.append_value("bye");
        input_builder.append_null();
        input_builder.append_value("rust");
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("6869");
        expected_builder.append_value("627965");
        expected_builder.append_null();
        expected_builder.append_value("72757374");
        let expected = expected_builder.finish();

        let result = spark_hex_array(Arc::new(input));
        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        input_builder.append_value(1);
        input_builder.append_value(2);
        input_builder.append_null();
        input_builder.append_value(3);
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("1");
        expected_builder.append_value("2");
        expected_builder.append_null();
        expected_builder.append_value("3");
        let expected = expected_builder.finish();

        let result = spark_hex_array(Arc::new(input));
        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("1");
        input_builder.append_value("j");
        input_builder.append_null();
        input_builder.append_value("3");
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("31");
        expected_builder.append_value("6A");
        expected_builder.append_null();
        expected_builder.append_value("33");
        let expected = expected_builder.finish();

        let result = spark_hex_array(Arc::new(input));
        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_hex_int64() {
        let test_cases = vec![(1234, "4D2"), (-1, "FFFFFFFFFFFFFFFF")];

        for (num, expected) in test_cases {
            let mut cache = [0u8; 16];
            let slice = super::hex_int64(num, &mut cache);

            unsafe {
                let result = from_utf8_unchecked(slice);
                assert_eq!(expected, result);
            }
        }
    }

    #[test]
    fn test_spark_hex_int64() {
        let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let result = spark_hex_array(Arc::new(int_array));

        let string_array = as_string_array(&result);
        let expected_array = StringArray::from(vec![
            Some("1".to_string()),
            Some("2".to_string()),
            None,
            Some("3".to_string()),
        ]);

        assert_eq!(string_array, &expected_array);
    }

    #[test]
    fn test_spark_hex_scalar_int64() {
        // Scalars are materialized as a one-element array.
        let args = [ColumnarValue::Scalar(ScalarValue::Int64(Some(255)))];
        let result = unwrap_array(super::spark_hex(&args).unwrap());

        let expected = StringArray::from(vec![Some("FF")]);
        assert_eq!(as_string_array(&result), &expected);
    }

    #[test]
    fn test_sha2_hex_lowercase_binary() {
        let input = BinaryArray::from(vec![Some(&[0xab_u8, 0x01][..]), None]);
        let args = [ColumnarValue::Array(Arc::new(input))];
        let result = unwrap_array(super::spark_sha2_hex(&args).unwrap());

        let expected = StringArray::from(vec![Some("ab01"), None]);
        assert_eq!(as_string_array(&result), &expected);
    }

    #[test]
    fn test_dict_values_null() {
        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let vals = Int64Array::from(vec![Some(32), None]);
        // [32, null, null]
        let dict = DictionaryArray::new(keys, Arc::new(vals));

        let result = spark_hex_array(Arc::new(dict));
        let result = as_dictionary_array(&result).unwrap();

        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let vals = StringArray::from(vec![Some("20"), None]);
        let expected = DictionaryArray::new(keys, Arc::new(vals));

        assert_eq!(&expected, result);
    }
}
489        let expected = DictionaryArray::new(keys, Arc::new(vals));
490
491        assert_eq!(&expected, result);
492    }
493}