datafusion_spark/function/math/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::error_utils::{
22    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
23};
24use arrow::array::{Array, StringArray};
25use arrow::datatypes::DataType;
26use arrow::{
27    array::{as_dictionary_array, as_largestring_array, as_string_array},
28    datatypes::Int32Type,
29};
30use datafusion_common::cast::as_string_view_array;
31use datafusion_common::{
32    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
33    exec_err, internal_err, DataFusionError,
34};
35use datafusion_expr::Signature;
36use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
37use std::fmt::Write;
38
/// Spark-compatible `hex` scalar function, exposed as a DataFusion UDF.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkHex {
    /// Signature returned by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Names returned by `ScalarUDFImpl::aliases`; `new` leaves this empty.
    aliases: Vec<String>,
}
45
46impl Default for SparkHex {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SparkHex {
53    pub fn new() -> Self {
54        Self {
55            signature: Signature::user_defined(Volatility::Immutable),
56            aliases: vec![],
57        }
58    }
59}
60
61impl ScalarUDFImpl for SparkHex {
62    fn as_any(&self) -> &dyn Any {
63        self
64    }
65
66    fn name(&self) -> &str {
67        "hex"
68    }
69
70    fn signature(&self) -> &Signature {
71        &self.signature
72    }
73
74    fn return_type(
75        &self,
76        _arg_types: &[DataType],
77    ) -> datafusion_common::Result<DataType> {
78        Ok(DataType::Utf8)
79    }
80
81    fn invoke_with_args(
82        &self,
83        args: ScalarFunctionArgs,
84    ) -> datafusion_common::Result<ColumnarValue> {
85        spark_hex(&args.args)
86    }
87
88    fn aliases(&self) -> &[String] {
89        &self.aliases
90    }
91
92    fn coerce_types(
93        &self,
94        arg_types: &[DataType],
95    ) -> datafusion_common::Result<Vec<DataType>> {
96        if arg_types.len() != 1 {
97            return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len()));
98        }
99        match &arg_types[0] {
100            DataType::Int64
101            | DataType::Utf8
102            | DataType::Utf8View
103            | DataType::LargeUtf8
104            | DataType::Binary
105            | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
106            DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
107                DataType::Int64
108                | DataType::Utf8
109                | DataType::Utf8View
110                | DataType::LargeUtf8
111                | DataType::Binary
112                | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
113                other => {
114                    if other.is_numeric() {
115                        Ok(vec![DataType::Dictionary(
116                            key_type.clone(),
117                            Box::new(DataType::Int64),
118                        )])
119                    } else {
120                        Err(unsupported_data_type_exec_err(
121                            "hex",
122                            "Numeric, String, or Binary",
123                            &arg_types[0],
124                        ))
125                    }
126                }
127            },
128            other => {
129                if other.is_numeric() {
130                    Ok(vec![DataType::Int64])
131                } else {
132                    Err(unsupported_data_type_exec_err(
133                        "hex",
134                        "Numeric, String, or Binary",
135                        &arg_types[0],
136                    ))
137                }
138            }
139        }
140    }
141}
142
/// Formats `num` as uppercase hexadecimal without prefix or padding.
///
/// Negative numbers are rendered as their 64-bit two's-complement bit
/// pattern, so `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    // Reinterpreting the bits as unsigned is lossless and prints the same
    // digits that `UpperHex` produces for the signed value.
    let bits = num as u64;
    format!("{bits:X}")
}
146
/// Encodes every byte of `data` as two hexadecimal digits.
///
/// Digits are lowercase when `lower_case` is `true` and uppercase otherwise.
/// Empty input yields an empty string.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    // Two output characters per input byte.
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        let written = if lower_case {
            write!(encoded, "{byte:02x}")
        } else {
            write!(encoded, "{byte:02X}")
        };
        // Writing to a string never errors, so we can unwrap here.
        written.unwrap();
    }
    encoded
}
163
164#[inline(always)]
165fn hex_bytes<T: AsRef<[u8]>>(
166    bytes: T,
167    lowercase: bool,
168) -> Result<String, std::fmt::Error> {
169    let hex_string = hex_encode(bytes, lowercase);
170    Ok(hex_string)
171}
172
173/// Spark-compatible `hex` function
174pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
175    compute_hex(args, false)
176}
177
178/// Spark-compatible `sha2` function
179pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
180    compute_hex(args, true)
181}
182
183pub fn compute_hex(
184    args: &[ColumnarValue],
185    lowercase: bool,
186) -> Result<ColumnarValue, DataFusionError> {
187    if args.len() != 1 {
188        return internal_err!("hex expects exactly one argument");
189    }
190
191    let input = match &args[0] {
192        ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
193        ColumnarValue::Array(_) => args[0].clone(),
194    };
195
196    match &input {
197        ColumnarValue::Array(array) => match array.data_type() {
198            DataType::Int64 => {
199                let array = as_int64_array(array)?;
200
201                let hexed_array: StringArray =
202                    array.iter().map(|v| v.map(hex_int64)).collect();
203
204                Ok(ColumnarValue::Array(Arc::new(hexed_array)))
205            }
206            DataType::Utf8 => {
207                let array = as_string_array(array);
208
209                let hexed: StringArray = array
210                    .iter()
211                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
212                    .collect::<Result<_, _>>()?;
213
214                Ok(ColumnarValue::Array(Arc::new(hexed)))
215            }
216            DataType::Utf8View => {
217                let array = as_string_view_array(array)?;
218
219                let hexed: StringArray = array
220                    .iter()
221                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
222                    .collect::<Result<_, _>>()?;
223
224                Ok(ColumnarValue::Array(Arc::new(hexed)))
225            }
226            DataType::LargeUtf8 => {
227                let array = as_largestring_array(array);
228
229                let hexed: StringArray = array
230                    .iter()
231                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
232                    .collect::<Result<_, _>>()?;
233
234                Ok(ColumnarValue::Array(Arc::new(hexed)))
235            }
236            DataType::Binary => {
237                let array = as_binary_array(array)?;
238
239                let hexed: StringArray = array
240                    .iter()
241                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
242                    .collect::<Result<_, _>>()?;
243
244                Ok(ColumnarValue::Array(Arc::new(hexed)))
245            }
246            DataType::FixedSizeBinary(_) => {
247                let array = as_fixed_size_binary_array(array)?;
248
249                let hexed: StringArray = array
250                    .iter()
251                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
252                    .collect::<Result<_, _>>()?;
253
254                Ok(ColumnarValue::Array(Arc::new(hexed)))
255            }
256            DataType::Dictionary(_, value_type) => {
257                let dict = as_dictionary_array::<Int32Type>(&array);
258
259                let values = match **value_type {
260                    DataType::Int64 => as_int64_array(dict.values())?
261                        .iter()
262                        .map(|v| v.map(hex_int64))
263                        .collect::<Vec<_>>(),
264                    DataType::Utf8 => as_string_array(dict.values())
265                        .iter()
266                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
267                        .collect::<Result<_, _>>()?,
268                    DataType::Binary => as_binary_array(dict.values())?
269                        .iter()
270                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
271                        .collect::<Result<_, _>>()?,
272                    _ => exec_err!(
273                        "hex got an unexpected argument type: {}",
274                        array.data_type()
275                    )?,
276                };
277
278                let new_values: Vec<Option<String>> = dict
279                    .keys()
280                    .iter()
281                    .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
282                    .collect();
283
284                let string_array_values = StringArray::from(new_values);
285
286                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
287            }
288            _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
289        },
290        _ => exec_err!("native hex does not support scalar values at this time"),
291    }
292}
293
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        as_string_array, ArrayRef, BinaryDictionaryBuilder, Int64Array,
        PrimitiveDictionaryBuilder, StringArray, StringDictionaryBuilder,
    };
    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on one array argument and returns the output as a
    /// `StringArray`, panicking when the call fails or yields a scalar.
    fn run_hex(input: ArrayRef) -> StringArray {
        let output = super::spark_hex(&[ColumnarValue::Array(input)]).unwrap();
        match output {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
        builder.append_value("hi");
        builder.append_value("bye");
        builder.append_null();
        builder.append_value("rust");

        let expected = StringArray::from(vec![
            Some("6869"),
            Some("627965"),
            None,
            Some("72757374"),
        ]);

        assert_eq!(run_hex(Arc::new(builder.finish())), expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        builder.append_value(1);
        builder.append_value(2);
        builder.append_null();
        builder.append_value(3);

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(run_hex(Arc::new(builder.finish())), expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut builder = BinaryDictionaryBuilder::<Int32Type>::new();
        builder.append_value("1");
        builder.append_value("j");
        builder.append_null();
        builder.append_value("3");

        let expected = StringArray::from(vec![Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(run_hex(Arc::new(builder.finish())), expected);
    }

    #[test]
    fn test_hex_int64() {
        // Negative input is printed as its two's-complement bit pattern.
        let cases = [(1234_i64, "4D2"), (-1_i64, "FFFFFFFFFFFFFFFF")];
        for (num, expected) in cases {
            assert_eq!(super::hex_int64(num), expected.to_string());
        }
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(run_hex(Arc::new(input)), expected);
    }
}