datafusion_spark/function/math/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::error_utils::{
22    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
23};
24use arrow::array::{Array, StringArray};
25use arrow::datatypes::DataType;
26use arrow::{
27    array::{as_dictionary_array, as_largestring_array, as_string_array},
28    datatypes::Int32Type,
29};
30use datafusion_common::cast::as_string_view_array;
31use datafusion_common::{
32    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
33    exec_err, DataFusionError,
34};
35use datafusion_expr::Signature;
36use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
37use std::fmt::Write;
38
39/// <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
40#[derive(Debug, PartialEq, Eq, Hash)]
41pub struct SparkHex {
42    signature: Signature,
43    aliases: Vec<String>,
44}
45
46impl Default for SparkHex {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SparkHex {
53    pub fn new() -> Self {
54        Self {
55            signature: Signature::user_defined(Volatility::Immutable),
56            aliases: vec![],
57        }
58    }
59}
60
61impl ScalarUDFImpl for SparkHex {
62    fn as_any(&self) -> &dyn Any {
63        self
64    }
65
66    fn name(&self) -> &str {
67        "hex"
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn return_type(
75        &self,
76        _arg_types: &[DataType],
77    ) -> datafusion_common::Result<DataType> {
78        Ok(DataType::Utf8)
79    }
80
81    fn invoke_with_args(
82        &self,
83        args: ScalarFunctionArgs,
84    ) -> datafusion_common::Result<ColumnarValue> {
85        spark_hex(&args.args)
86    }
87
88    fn aliases(&self) -> &[String] {
89        &self.aliases
90    }
91
92    fn coerce_types(
93        &self,
94        arg_types: &[DataType],
95    ) -> datafusion_common::Result<Vec<DataType>> {
96        if arg_types.len() != 1 {
97            return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len()));
98        }
99        match &arg_types[0] {
100            DataType::Int64
101            | DataType::Utf8
102            | DataType::Utf8View
103            | DataType::LargeUtf8
104            | DataType::Binary
105            | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
106            DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
107                DataType::Int64
108                | DataType::Utf8
109                | DataType::Utf8View
110                | DataType::LargeUtf8
111                | DataType::Binary
112                | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
113                other => {
114                    if other.is_numeric() {
115                        Ok(vec![DataType::Dictionary(
116                            key_type.clone(),
117                            Box::new(DataType::Int64),
118                        )])
119                    } else {
120                        Err(unsupported_data_type_exec_err(
121                            "hex",
122                            "Numeric, String, or Binary",
123                            &arg_types[0],
124                        ))
125                    }
126                }
127            },
128            other => {
129                if other.is_numeric() {
130                    Ok(vec![DataType::Int64])
131                } else {
132                    Err(unsupported_data_type_exec_err(
133                        "hex",
134                        "Numeric, String, or Binary",
135                        &arg_types[0],
136                    ))
137                }
138            }
139        }
140    }
141}
142
/// Renders `num` as upper-case hexadecimal without a prefix or padding.
///
/// Negative values are rendered as their 64-bit two's-complement bit pattern,
/// so `-1` becomes sixteen `F` digits.
fn hex_int64(num: i64) -> String {
    // Sixteen digits is the widest an `i64` can render in base 16.
    let mut digits = String::with_capacity(16);
    write!(digits, "{num:X}").expect("writing to a String cannot fail");
    digits
}
146
/// Encodes `data` as two hexadecimal digits per byte.
///
/// Digits are lower-case when `lower_case` is `true` and upper-case otherwise.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    let mut encoded = String::with_capacity(bytes.len() * 2);

    // The case is chosen once, outside the per-byte loop.
    let written = if lower_case {
        bytes
            .iter()
            .try_for_each(|byte| write!(encoded, "{byte:02x}"))
    } else {
        bytes
            .iter()
            .try_for_each(|byte| write!(encoded, "{byte:02X}"))
    };
    written.expect("writing to a String cannot fail");

    encoded
}
163
164#[inline(always)]
165fn hex_bytes<T: AsRef<[u8]>>(
166    bytes: T,
167    lowercase: bool,
168) -> Result<String, std::fmt::Error> {
169    let hex_string = hex_encode(bytes, lowercase);
170    Ok(hex_string)
171}
172
173/// Spark-compatible `hex` function
174pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
175    compute_hex(args, false)
176}
177
178/// Spark-compatible `sha2` function
179pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
180    compute_hex(args, true)
181}
182
183pub fn compute_hex(
184    args: &[ColumnarValue],
185    lowercase: bool,
186) -> Result<ColumnarValue, DataFusionError> {
187    if args.len() != 1 {
188        return Err(DataFusionError::Internal(
189            "hex expects exactly one argument".to_string(),
190        ));
191    }
192
193    let input = match &args[0] {
194        ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
195        ColumnarValue::Array(_) => args[0].clone(),
196    };
197
198    match &input {
199        ColumnarValue::Array(array) => match array.data_type() {
200            DataType::Int64 => {
201                let array = as_int64_array(array)?;
202
203                let hexed_array: StringArray =
204                    array.iter().map(|v| v.map(hex_int64)).collect();
205
206                Ok(ColumnarValue::Array(Arc::new(hexed_array)))
207            }
208            DataType::Utf8 => {
209                let array = as_string_array(array);
210
211                let hexed: StringArray = array
212                    .iter()
213                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
214                    .collect::<Result<_, _>>()?;
215
216                Ok(ColumnarValue::Array(Arc::new(hexed)))
217            }
218            DataType::Utf8View => {
219                let array = as_string_view_array(array)?;
220
221                let hexed: StringArray = array
222                    .iter()
223                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
224                    .collect::<Result<_, _>>()?;
225
226                Ok(ColumnarValue::Array(Arc::new(hexed)))
227            }
228            DataType::LargeUtf8 => {
229                let array = as_largestring_array(array);
230
231                let hexed: StringArray = array
232                    .iter()
233                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
234                    .collect::<Result<_, _>>()?;
235
236                Ok(ColumnarValue::Array(Arc::new(hexed)))
237            }
238            DataType::Binary => {
239                let array = as_binary_array(array)?;
240
241                let hexed: StringArray = array
242                    .iter()
243                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
244                    .collect::<Result<_, _>>()?;
245
246                Ok(ColumnarValue::Array(Arc::new(hexed)))
247            }
248            DataType::FixedSizeBinary(_) => {
249                let array = as_fixed_size_binary_array(array)?;
250
251                let hexed: StringArray = array
252                    .iter()
253                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
254                    .collect::<Result<_, _>>()?;
255
256                Ok(ColumnarValue::Array(Arc::new(hexed)))
257            }
258            DataType::Dictionary(_, value_type) => {
259                let dict = as_dictionary_array::<Int32Type>(&array);
260
261                let values = match **value_type {
262                    DataType::Int64 => as_int64_array(dict.values())?
263                        .iter()
264                        .map(|v| v.map(hex_int64))
265                        .collect::<Vec<_>>(),
266                    DataType::Utf8 => as_string_array(dict.values())
267                        .iter()
268                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
269                        .collect::<Result<_, _>>()?,
270                    DataType::Binary => as_binary_array(dict.values())?
271                        .iter()
272                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
273                        .collect::<Result<_, _>>()?,
274                    _ => exec_err!(
275                        "hex got an unexpected argument type: {:?}",
276                        array.data_type()
277                    )?,
278                };
279
280                let new_values: Vec<Option<String>> = dict
281                    .keys()
282                    .iter()
283                    .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
284                    .collect();
285
286                let string_array_values = StringArray::from(new_values);
287
288                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
289            }
290            _ => exec_err!(
291                "hex got an unexpected argument type: {:?}",
292                array.data_type()
293            ),
294        },
295        _ => exec_err!("native hex does not support scalar values at this time"),
296    }
297}
298
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::{
        array::{
            as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder,
            StringDictionaryBuilder,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on `input` and returns the output as a `StringArray`.
    ///
    /// Panics if the call fails or does not produce an array.
    fn hex_as_strings(input: ColumnarValue) -> StringArray {
        match super::spark_hex(&[input]).unwrap() {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut dict = StringDictionaryBuilder::<Int32Type>::new();
        dict.append_value("hi");
        dict.append_value("bye");
        dict.append_null();
        dict.append_value("rust");

        let actual = hex_as_strings(ColumnarValue::Array(Arc::new(dict.finish())));
        let expected = StringArray::from(vec![
            Some("6869"),
            Some("627965"),
            None,
            Some("72757374"),
        ]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut dict = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        dict.append_value(1);
        dict.append_value(2);
        dict.append_null();
        dict.append_value(3);

        let actual = hex_as_strings(ColumnarValue::Array(Arc::new(dict.finish())));
        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut dict = BinaryDictionaryBuilder::<Int32Type>::new();
        dict.append_value("1");
        dict.append_value("j");
        dict.append_null();
        dict.append_value("3");

        let actual = hex_as_strings(ColumnarValue::Array(Arc::new(dict.finish())));
        let expected = StringArray::from(vec![Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_hex_int64() {
        for (num, expected) in [(1234_i64, "4D2"), (-1, "FFFFFFFFFFFFFFFF")] {
            assert_eq!(super::hex_int64(num), expected);
        }
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let actual = hex_as_strings(ColumnarValue::Array(Arc::new(input)));
        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(actual, expected);
    }
}