datafusion_spark/function/math/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::error_utils::{
22    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
23};
24use arrow::array::{Array, StringArray};
25use arrow::datatypes::DataType;
26use arrow::{
27    array::{as_dictionary_array, as_largestring_array, as_string_array},
28    datatypes::Int32Type,
29};
30use datafusion_common::{
31    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
32    exec_err, DataFusionError,
33};
34use datafusion_expr::Signature;
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
36use std::fmt::Write;
37
/// Scalar UDF implementing the Spark-compatible `hex` function.
///
/// Reference: <https://spark.apache.org/docs/latest/api/sql/index.html#hex>
#[derive(Debug)]
pub struct SparkHex {
    /// Signature handed back by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative names handed back by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
44
45impl Default for SparkHex {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl SparkHex {
52    pub fn new() -> Self {
53        Self {
54            signature: Signature::user_defined(Volatility::Immutable),
55            aliases: vec![],
56        }
57    }
58}
59
60impl ScalarUDFImpl for SparkHex {
61    fn as_any(&self) -> &dyn Any {
62        self
63    }
64
65    fn name(&self) -> &str {
66        "hex"
67    }
68
69    fn signature(&self) -> &Signature {
70        &self.signature
71    }
72
73    fn return_type(
74        &self,
75        _arg_types: &[DataType],
76    ) -> datafusion_common::Result<DataType> {
77        Ok(DataType::Utf8)
78    }
79
80    fn invoke_with_args(
81        &self,
82        args: ScalarFunctionArgs,
83    ) -> datafusion_common::Result<ColumnarValue> {
84        spark_hex(&args.args)
85    }
86
87    fn aliases(&self) -> &[String] {
88        &self.aliases
89    }
90
91    fn coerce_types(
92        &self,
93        arg_types: &[DataType],
94    ) -> datafusion_common::Result<Vec<DataType>> {
95        if arg_types.len() != 1 {
96            return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len()));
97        }
98        match &arg_types[0] {
99            DataType::Int64
100            | DataType::Utf8
101            | DataType::LargeUtf8
102            | DataType::Binary
103            | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
104            DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
105                DataType::Int64
106                | DataType::Utf8
107                | DataType::LargeUtf8
108                | DataType::Binary
109                | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
110                other => {
111                    if other.is_numeric() {
112                        Ok(vec![DataType::Dictionary(
113                            key_type.clone(),
114                            Box::new(DataType::Int64),
115                        )])
116                    } else {
117                        Err(unsupported_data_type_exec_err(
118                            "hex",
119                            "Numeric, String, or Binary",
120                            &arg_types[0],
121                        ))
122                    }
123                }
124            },
125            other => {
126                if other.is_numeric() {
127                    Ok(vec![DataType::Int64])
128                } else {
129                    Err(unsupported_data_type_exec_err(
130                        "hex",
131                        "Numeric, String, or Binary",
132                        &arg_types[0],
133                    ))
134                }
135            }
136        }
137    }
138}
139
/// Formats `num` as uppercase hexadecimal without leading zeros or prefix.
///
/// Negative numbers are rendered as their 64-bit two's-complement bit
/// pattern, e.g. `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    format!("{:X}", num)
}
143
/// Renders every byte of `data` as two hexadecimal digits.
///
/// Digits `a`-`f` are lowercase when `lower_case` is `true` and uppercase
/// otherwise. An empty input yields an empty string.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    // Two output characters per input byte.
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        let outcome = if lower_case {
            write!(encoded, "{byte:02x}")
        } else {
            write!(encoded, "{byte:02X}")
        };
        // Formatting into a `String` is infallible.
        outcome.expect("writing to a String never fails");
    }
    encoded
}
160
161#[inline(always)]
162fn hex_bytes<T: AsRef<[u8]>>(
163    bytes: T,
164    lowercase: bool,
165) -> Result<String, std::fmt::Error> {
166    let hex_string = hex_encode(bytes, lowercase);
167    Ok(hex_string)
168}
169
170/// Spark-compatible `hex` function
171pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
172    compute_hex(args, false)
173}
174
175/// Spark-compatible `sha2` function
176pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
177    compute_hex(args, true)
178}
179
180pub fn compute_hex(
181    args: &[ColumnarValue],
182    lowercase: bool,
183) -> Result<ColumnarValue, DataFusionError> {
184    if args.len() != 1 {
185        return Err(DataFusionError::Internal(
186            "hex expects exactly one argument".to_string(),
187        ));
188    }
189
190    let input = match &args[0] {
191        ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
192        ColumnarValue::Array(_) => args[0].clone(),
193    };
194
195    match &input {
196        ColumnarValue::Array(array) => match array.data_type() {
197            DataType::Int64 => {
198                let array = as_int64_array(array)?;
199
200                let hexed_array: StringArray =
201                    array.iter().map(|v| v.map(hex_int64)).collect();
202
203                Ok(ColumnarValue::Array(Arc::new(hexed_array)))
204            }
205            DataType::Utf8 => {
206                let array = as_string_array(array);
207
208                let hexed: StringArray = array
209                    .iter()
210                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
211                    .collect::<Result<_, _>>()?;
212
213                Ok(ColumnarValue::Array(Arc::new(hexed)))
214            }
215            DataType::LargeUtf8 => {
216                let array = as_largestring_array(array);
217
218                let hexed: StringArray = array
219                    .iter()
220                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
221                    .collect::<Result<_, _>>()?;
222
223                Ok(ColumnarValue::Array(Arc::new(hexed)))
224            }
225            DataType::Binary => {
226                let array = as_binary_array(array)?;
227
228                let hexed: StringArray = array
229                    .iter()
230                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
231                    .collect::<Result<_, _>>()?;
232
233                Ok(ColumnarValue::Array(Arc::new(hexed)))
234            }
235            DataType::FixedSizeBinary(_) => {
236                let array = as_fixed_size_binary_array(array)?;
237
238                let hexed: StringArray = array
239                    .iter()
240                    .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
241                    .collect::<Result<_, _>>()?;
242
243                Ok(ColumnarValue::Array(Arc::new(hexed)))
244            }
245            DataType::Dictionary(_, value_type) => {
246                let dict = as_dictionary_array::<Int32Type>(&array);
247
248                let values = match **value_type {
249                    DataType::Int64 => as_int64_array(dict.values())?
250                        .iter()
251                        .map(|v| v.map(hex_int64))
252                        .collect::<Vec<_>>(),
253                    DataType::Utf8 => as_string_array(dict.values())
254                        .iter()
255                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
256                        .collect::<Result<_, _>>()?,
257                    DataType::Binary => as_binary_array(dict.values())?
258                        .iter()
259                        .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
260                        .collect::<Result<_, _>>()?,
261                    _ => exec_err!(
262                        "hex got an unexpected argument type: {:?}",
263                        array.data_type()
264                    )?,
265                };
266
267                let new_values: Vec<Option<String>> = dict
268                    .keys()
269                    .iter()
270                    .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
271                    .collect();
272
273                let string_array_values = StringArray::from(new_values);
274
275                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
276            }
277            _ => exec_err!(
278                "hex got an unexpected argument type: {:?}",
279                array.data_type()
280            ),
281        },
282        _ => exec_err!("native hex does not support scalar values at this time"),
283    }
284}
285
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::{
        array::{
            as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder,
            StringBuilder, StringDictionaryBuilder,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on one argument and returns the resulting strings,
    /// panicking when the call fails or does not produce an array.
    fn run_spark_hex(input: ColumnarValue) -> StringArray {
        match super::spark_hex(&[input]).unwrap() {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    /// Builds the expected output column from optional string slices.
    fn expected_strings(items: &[Option<&str>]) -> StringArray {
        let mut builder = StringBuilder::new();
        for item in items {
            match item {
                Some(text) => builder.append_value(*text),
                None => builder.append_null(),
            }
        }
        builder.finish()
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
        dict_builder.append_value("hi");
        dict_builder.append_value("bye");
        dict_builder.append_null();
        dict_builder.append_value("rust");
        let dict = dict_builder.finish();

        let actual = run_spark_hex(ColumnarValue::Array(Arc::new(dict)));
        let expected =
            expected_strings(&[Some("6869"), Some("627965"), None, Some("72757374")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut dict_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        dict_builder.append_value(1);
        dict_builder.append_value(2);
        dict_builder.append_null();
        dict_builder.append_value(3);
        let dict = dict_builder.finish();

        let actual = run_spark_hex(ColumnarValue::Array(Arc::new(dict)));
        let expected = expected_strings(&[Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut dict_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        dict_builder.append_value("1");
        dict_builder.append_value("j");
        dict_builder.append_null();
        dict_builder.append_value("3");
        let dict = dict_builder.finish();

        let actual = run_spark_hex(ColumnarValue::Array(Arc::new(dict)));
        let expected = expected_strings(&[Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_hex_int64() {
        assert_eq!(super::hex_int64(1234), "4D2".to_string());
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let ints = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let actual = run_spark_hex(ColumnarValue::Array(Arc::new(ints)));
        let expected = StringArray::from(vec![
            Some("1".to_string()),
            Some("2".to_string()),
            None,
            Some("3".to_string()),
        ]);

        assert_eq!(actual, expected);
    }
}