datafusion_comet_spark_expr/math_funcs/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::{
21    array::{as_dictionary_array, as_largestring_array, as_string_array},
22    datatypes::Int32Type,
23};
24use arrow_array::StringArray;
25use arrow_schema::DataType;
26use datafusion::logical_expr::ColumnarValue;
27use datafusion_common::{
28    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
29    exec_err, DataFusionError,
30};
31use std::fmt::Write;
32
/// Formats `num` as an upper-case hexadecimal string without any prefix or
/// zero padding.
///
/// Negative values are rendered using their 64-bit two's-complement bit
/// pattern, e.g. `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    format!("{num:X}")
}
36
/// Encodes `data` as a hexadecimal string, emitting exactly two digits per
/// input byte.
///
/// Digits `a`-`f` are lower-case when `lower_case` is `true` and upper-case
/// otherwise. Empty input yields an empty string.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    // Every byte expands to two hex digits, so the final length is known up front.
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        let written = if lower_case {
            write!(encoded, "{byte:02x}")
        } else {
            write!(encoded, "{byte:02X}")
        };
        // Writing to a string never errors, so we can unwrap here.
        written.unwrap();
    }
    encoded
}
53
54#[inline(always)]
55pub(crate) fn hex_strings<T: AsRef<[u8]>>(data: T) -> String {
56    hex_encode(data, true)
57}
58
59#[inline(always)]
60fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> {
61    let hex_string = hex_encode(bytes, false);
62    Ok(hex_string)
63}
64
65/// Spark-compatible `hex` function
66pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
67    if args.len() != 1 {
68        return Err(DataFusionError::Internal(
69            "hex expects exactly one argument".to_string(),
70        ));
71    }
72
73    match &args[0] {
74        ColumnarValue::Array(array) => match array.data_type() {
75            DataType::Int64 => {
76                let array = as_int64_array(array)?;
77
78                let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect();
79
80                Ok(ColumnarValue::Array(Arc::new(hexed_array)))
81            }
82            DataType::Utf8 => {
83                let array = as_string_array(array);
84
85                let hexed: StringArray = array
86                    .iter()
87                    .map(|v| v.map(hex_bytes).transpose())
88                    .collect::<Result<_, _>>()?;
89
90                Ok(ColumnarValue::Array(Arc::new(hexed)))
91            }
92            DataType::LargeUtf8 => {
93                let array = as_largestring_array(array);
94
95                let hexed: StringArray = array
96                    .iter()
97                    .map(|v| v.map(hex_bytes).transpose())
98                    .collect::<Result<_, _>>()?;
99
100                Ok(ColumnarValue::Array(Arc::new(hexed)))
101            }
102            DataType::Binary => {
103                let array = as_binary_array(array)?;
104
105                let hexed: StringArray = array
106                    .iter()
107                    .map(|v| v.map(hex_bytes).transpose())
108                    .collect::<Result<_, _>>()?;
109
110                Ok(ColumnarValue::Array(Arc::new(hexed)))
111            }
112            DataType::FixedSizeBinary(_) => {
113                let array = as_fixed_size_binary_array(array)?;
114
115                let hexed: StringArray = array
116                    .iter()
117                    .map(|v| v.map(hex_bytes).transpose())
118                    .collect::<Result<_, _>>()?;
119
120                Ok(ColumnarValue::Array(Arc::new(hexed)))
121            }
122            DataType::Dictionary(_, value_type) => {
123                let dict = as_dictionary_array::<Int32Type>(&array);
124
125                let values = match **value_type {
126                    DataType::Int64 => as_int64_array(dict.values())?
127                        .iter()
128                        .map(|v| v.map(hex_int64))
129                        .collect::<Vec<_>>(),
130                    DataType::Utf8 => as_string_array(dict.values())
131                        .iter()
132                        .map(|v| v.map(hex_bytes).transpose())
133                        .collect::<Result<_, _>>()?,
134                    DataType::Binary => as_binary_array(dict.values())?
135                        .iter()
136                        .map(|v| v.map(hex_bytes).transpose())
137                        .collect::<Result<_, _>>()?,
138                    _ => exec_err!(
139                        "hex got an unexpected argument type: {:?}",
140                        array.data_type()
141                    )?,
142                };
143
144                let new_values: Vec<Option<String>> = dict
145                    .keys()
146                    .iter()
147                    .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
148                    .collect();
149
150                let string_array_values = StringArray::from(new_values);
151
152                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
153            }
154            _ => exec_err!(
155                "hex got an unexpected argument type: {:?}",
156                array.data_type()
157            ),
158        },
159        _ => exec_err!("native hex does not support scalar values at this time"),
160    }
161}
162
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::{
        array::{
            as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder,
            StringDictionaryBuilder,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use arrow_array::{Int64Array, StringArray};
    use datafusion::logical_expr::ColumnarValue;

    /// Runs `spark_hex` on a single argument and returns the produced string
    /// array, panicking if the call fails or does not return an array.
    fn hex_of(input: ColumnarValue) -> StringArray {
        match super::spark_hex(&[input]).unwrap() {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    /// Builds the expected `StringArray` from optional string slices.
    fn expected_strings(values: &[Option<&str>]) -> StringArray {
        let mut builder = StringBuilder::new();
        for value in values {
            match value {
                Some(text) => builder.append_value(text),
                None => builder.append_null(),
            }
        }
        builder.finish()
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("hi");
        input_builder.append_value("bye");
        input_builder.append_null();
        input_builder.append_value("rust");
        let input = input_builder.finish();

        let actual = hex_of(ColumnarValue::Array(Arc::new(input)));
        let expected = expected_strings(&[Some("6869"), Some("627965"), None, Some("72757374")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        input_builder.append_value(1);
        input_builder.append_value(2);
        input_builder.append_null();
        input_builder.append_value(3);
        let input = input_builder.finish();

        let actual = hex_of(ColumnarValue::Array(Arc::new(input)));
        let expected = expected_strings(&[Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("1");
        input_builder.append_value("j");
        input_builder.append_null();
        input_builder.append_value("3");
        let input = input_builder.finish();

        let actual = hex_of(ColumnarValue::Array(Arc::new(input)));
        let expected = expected_strings(&[Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_hex_int64() {
        // Positive values are rendered without padding.
        assert_eq!(super::hex_int64(1234), "4D2".to_string());

        // Negative values use the 64-bit two's-complement representation.
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let actual = hex_of(ColumnarValue::Array(Arc::new(input)));
        let expected = StringArray::from(vec![
            Some("1".to_string()),
            Some("2".to_string()),
            None,
            Some("3".to_string()),
        ]);

        assert_eq!(actual, expected);
    }
}
296}