Skip to main content

datafusion_spark/function/math/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, StringArray};
22use arrow::datatypes::DataType;
23use arrow::{
24    array::{as_dictionary_array, as_largestring_array, as_string_array},
25    datatypes::Int32Type,
26};
27use datafusion_common::cast::as_large_binary_array;
28use datafusion_common::cast::as_string_view_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{
32    DataFusionError,
33    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
34    exec_err,
35};
36use datafusion_expr::{
37    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
38    TypeSignatureClass, Volatility,
39};
40/// <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkHex {
43    signature: Signature,
44    aliases: Vec<String>,
45}
46
47impl Default for SparkHex {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkHex {
54    pub fn new() -> Self {
55        let int64 = Coercion::new_implicit(
56            TypeSignatureClass::Native(logical_int64()),
57            vec![TypeSignatureClass::Numeric],
58            NativeType::Int64,
59        );
60
61        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
62
63        let binary = Coercion::new_exact(TypeSignatureClass::Binary);
64
65        let variants = vec![
66            // accepts numeric types
67            TypeSignature::Coercible(vec![int64]),
68            // accepts string types (Utf8, Utf8View, LargeUtf8)
69            TypeSignature::Coercible(vec![string]),
70            // accepts binary types (Binary, FixedSizeBinary, LargeBinary)
71            TypeSignature::Coercible(vec![binary]),
72        ];
73
74        Self {
75            signature: Signature::one_of(variants, Volatility::Immutable),
76            aliases: vec![],
77        }
78    }
79}
80
81impl ScalarUDFImpl for SparkHex {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "hex"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(
95        &self,
96        _arg_types: &[DataType],
97    ) -> datafusion_common::Result<DataType> {
98        Ok(DataType::Utf8)
99    }
100
101    fn invoke_with_args(
102        &self,
103        args: ScalarFunctionArgs,
104    ) -> datafusion_common::Result<ColumnarValue> {
105        spark_hex(&args.args)
106    }
107
108    fn aliases(&self) -> &[String] {
109        &self.aliases
110    }
111}
112
/// Formats `num` as uppercase hexadecimal with no prefix and no zero padding.
///
/// Negative numbers are rendered as their 64-bit two's-complement bit
/// pattern, e.g. `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    // Reinterpreting the bits as unsigned yields exactly the digits that the
    // signed `{:X}` formatter prints.
    let bits = num as u64;
    format!("{bits:X}")
}
116
/// Lowercase hex digits, indexed by nibble value.
const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
/// Uppercase hex digits, indexed by nibble value.
const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF";

/// Encodes the bytes of `data` as a hex string, two digits per byte.
///
/// `lower_case` selects `a-f` over `A-F` for digits above nine. Empty input
/// yields an empty string.
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let digits = if lower_case {
        HEX_CHARS_LOWER
    } else {
        HEX_CHARS_UPPER
    };

    let input = data.as_ref();
    // Each byte expands to exactly two characters.
    let mut encoded = String::with_capacity(input.len() * 2);
    for byte in input.iter().copied() {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0f);
        encoded.push(char::from(digits[high]));
        encoded.push(char::from(digits[low]));
    }
    encoded
}
136
137#[inline(always)]
138fn hex_bytes<T: AsRef<[u8]>>(
139    bytes: T,
140    lowercase: bool,
141) -> Result<String, std::fmt::Error> {
142    let hex_string = hex_encode(bytes, lowercase);
143    Ok(hex_string)
144}
145
146/// Spark-compatible `hex` function
147pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
148    compute_hex(args, false)
149}
150
151/// Spark-compatible `sha2` function
152pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
153    compute_hex(args, true)
154}
155
156pub fn compute_hex(
157    args: &[ColumnarValue],
158    lowercase: bool,
159) -> Result<ColumnarValue, DataFusionError> {
160    let input = match take_function_args("hex", args)? {
161        [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
162        [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
163    };
164
165    match &input {
166        ColumnarValue::Array(array) => match array.data_type() {
167            DataType::Int64 => {
168                let array = as_int64_array(array)?;
169
170                let hexed_array: StringArray =
171                    array.iter().map(|v| v.map(hex_int64)).collect();
172
173                Ok(ColumnarValue::Array(Arc::new(hexed_array)))
174            }
175            DataType::Utf8 => {
176                let array = as_string_array(array);
177
178                let hexed: StringArray = array
179                    .iter()
180                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
181                    .collect::<Result<_, _>>()?;
182
183                Ok(ColumnarValue::Array(Arc::new(hexed)))
184            }
185            DataType::Utf8View => {
186                let array = as_string_view_array(array)?;
187
188                let hexed: StringArray = array
189                    .iter()
190                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
191                    .collect::<Result<_, _>>()?;
192
193                Ok(ColumnarValue::Array(Arc::new(hexed)))
194            }
195            DataType::LargeUtf8 => {
196                let array = as_largestring_array(array);
197
198                let hexed: StringArray = array
199                    .iter()
200                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
201                    .collect::<Result<_, _>>()?;
202
203                Ok(ColumnarValue::Array(Arc::new(hexed)))
204            }
205            DataType::Binary => {
206                let array = as_binary_array(array)?;
207
208                let hexed: StringArray = array
209                    .iter()
210                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
211                    .collect::<Result<_, _>>()?;
212
213                Ok(ColumnarValue::Array(Arc::new(hexed)))
214            }
215            DataType::LargeBinary => {
216                let array = as_large_binary_array(array)?;
217
218                let hexed: StringArray = array
219                    .iter()
220                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
221                    .collect::<Result<_, _>>()?;
222
223                Ok(ColumnarValue::Array(Arc::new(hexed)))
224            }
225            DataType::FixedSizeBinary(_) => {
226                let array = as_fixed_size_binary_array(array)?;
227
228                let hexed: StringArray = array
229                    .iter()
230                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
231                    .collect::<Result<_, _>>()?;
232
233                Ok(ColumnarValue::Array(Arc::new(hexed)))
234            }
235            DataType::Dictionary(_, value_type) => {
236                let dict = as_dictionary_array::<Int32Type>(&array);
237
238                let values = match **value_type {
239                    DataType::Int64 => as_int64_array(dict.values())?
240                        .iter()
241                        .map(|v| v.map(hex_int64))
242                        .collect::<Vec<_>>(),
243                    DataType::Utf8 => as_string_array(dict.values())
244                        .iter()
245                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
246                        .collect::<Result<_, _>>()?,
247                    DataType::Binary => as_binary_array(dict.values())?
248                        .iter()
249                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
250                        .collect::<Result<_, _>>()?,
251                    _ => exec_err!(
252                        "hex got an unexpected argument type: {}",
253                        array.data_type()
254                    )?,
255                };
256
257                let new_values: Vec<Option<String>> = dict
258                    .keys()
259                    .iter()
260                    .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
261                    .collect();
262
263                let string_array_values = StringArray::from(new_values);
264
265                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
266            }
267            _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
268        },
269        _ => exec_err!("native hex does not support scalar values at this time"),
270    }
271}
272
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::{
        array::{
            BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
            as_string_array,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on `input` and asserts that the resulting string
    /// array equals `expected`.
    fn assert_spark_hex<A>(input: A, expected: Vec<Option<&str>>)
    where
        A: arrow::array::Array + 'static,
    {
        let output = super::spark_hex(&[ColumnarValue::Array(Arc::new(input))]).unwrap();
        let output = match output {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let expected = StringArray::from(expected);
        assert_eq!(as_string_array(&output), &expected);
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
        builder.append_value("hi");
        builder.append_value("bye");
        builder.append_null();
        builder.append_value("rust");

        assert_spark_hex(
            builder.finish(),
            vec![Some("6869"), Some("627965"), None, Some("72757374")],
        );
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        builder.append_value(1);
        builder.append_value(2);
        builder.append_null();
        builder.append_value(3);

        assert_spark_hex(builder.finish(), vec![Some("1"), Some("2"), None, Some("3")]);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut builder = BinaryDictionaryBuilder::<Int32Type>::new();
        builder.append_value("1");
        builder.append_value("j");
        builder.append_null();
        builder.append_value("3");

        assert_spark_hex(
            builder.finish(),
            vec![Some("31"), Some("6A"), None, Some("33")],
        );
    }

    #[test]
    fn test_hex_int64() {
        assert_eq!(super::hex_int64(1234), "4D2".to_string());

        // Negative values come out as the 64-bit two's-complement pattern.
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        assert_spark_hex(input, vec![Some("1"), Some("2"), None, Some("3")]);
    }
}