datafusion_comet_spark_expr/math_funcs/
hex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::StringArray;
21use arrow::datatypes::DataType;
22use arrow::{
23    array::{as_dictionary_array, as_largestring_array, as_string_array},
24    datatypes::Int32Type,
25};
26use datafusion::common::{
27    cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
28    exec_err, DataFusionError,
29};
30use datafusion::logical_expr::ColumnarValue;
31use std::fmt::Write;
32
/// Renders `num` as upper-case hexadecimal with no leading zeros.
///
/// Negative inputs are shown as their full 64-bit two's-complement bit
/// pattern, so `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    // Reinterpreting the bits as unsigned is intentional: it produces the same
    // digits `{:X}` prints for a signed value.
    let bits = num as u64;
    format!("{bits:X}")
}
36
/// Encodes every byte of `data` as exactly two hexadecimal digits.
///
/// `lower_case` selects `a`-`f` digits instead of `A`-`F`. The returned string
/// is always twice as long as the input.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    let mut encoded = String::with_capacity(2 * bytes.len());
    for &byte in bytes {
        let written = if lower_case {
            write!(encoded, "{byte:02x}")
        } else {
            write!(encoded, "{byte:02X}")
        };
        // `fmt::Write` for `String` never reports an error, so this cannot fire.
        written.expect("writing to a String is infallible");
    }
    encoded
}
53
54#[inline(always)]
55fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> {
56    let hex_string = hex_encode(bytes, false);
57    Ok(hex_string)
58}
59
60/// Spark-compatible `hex` function
61pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
62    if args.len() != 1 {
63        return Err(DataFusionError::Internal(
64            "hex expects exactly one argument".to_string(),
65        ));
66    }
67
68    match &args[0] {
69        ColumnarValue::Array(array) => match array.data_type() {
70            DataType::Int64 => {
71                let array = as_int64_array(array)?;
72
73                let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect();
74
75                Ok(ColumnarValue::Array(Arc::new(hexed_array)))
76            }
77            DataType::Utf8 => {
78                let array = as_string_array(array);
79
80                let hexed: StringArray = array
81                    .iter()
82                    .map(|v| v.map(hex_bytes).transpose())
83                    .collect::<Result<_, _>>()?;
84
85                Ok(ColumnarValue::Array(Arc::new(hexed)))
86            }
87            DataType::LargeUtf8 => {
88                let array = as_largestring_array(array);
89
90                let hexed: StringArray = array
91                    .iter()
92                    .map(|v| v.map(hex_bytes).transpose())
93                    .collect::<Result<_, _>>()?;
94
95                Ok(ColumnarValue::Array(Arc::new(hexed)))
96            }
97            DataType::Binary => {
98                let array = as_binary_array(array)?;
99
100                let hexed: StringArray = array
101                    .iter()
102                    .map(|v| v.map(hex_bytes).transpose())
103                    .collect::<Result<_, _>>()?;
104
105                Ok(ColumnarValue::Array(Arc::new(hexed)))
106            }
107            DataType::FixedSizeBinary(_) => {
108                let array = as_fixed_size_binary_array(array)?;
109
110                let hexed: StringArray = array
111                    .iter()
112                    .map(|v| v.map(hex_bytes).transpose())
113                    .collect::<Result<_, _>>()?;
114
115                Ok(ColumnarValue::Array(Arc::new(hexed)))
116            }
117            DataType::Dictionary(_, value_type) => {
118                let dict = as_dictionary_array::<Int32Type>(&array);
119
120                let values = match **value_type {
121                    DataType::Int64 => as_int64_array(dict.values())?
122                        .iter()
123                        .map(|v| v.map(hex_int64))
124                        .collect::<Vec<_>>(),
125                    DataType::Utf8 => as_string_array(dict.values())
126                        .iter()
127                        .map(|v| v.map(hex_bytes).transpose())
128                        .collect::<Result<_, _>>()?,
129                    DataType::Binary => as_binary_array(dict.values())?
130                        .iter()
131                        .map(|v| v.map(hex_bytes).transpose())
132                        .collect::<Result<_, _>>()?,
133                    _ => exec_err!(
134                        "hex got an unexpected argument type: {:?}",
135                        array.data_type()
136                    )?,
137                };
138
139                let new_values: Vec<Option<String>> = dict
140                    .keys()
141                    .iter()
142                    .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
143                    .collect();
144
145                let string_array_values = StringArray::from(new_values);
146
147                Ok(ColumnarValue::Array(Arc::new(string_array_values)))
148            }
149            _ => exec_err!(
150                "hex got an unexpected argument type: {:?}",
151                array.data_type()
152            ),
153        },
154        _ => exec_err!("native hex does not support scalar values at this time"),
155    }
156}
157
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        as_string_array, ArrayRef, BinaryDictionaryBuilder, Int64Array,
        PrimitiveDictionaryBuilder, StringArray, StringDictionaryBuilder,
    };
    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion::logical_expr::ColumnarValue;

    /// Calls `spark_hex` with `input` as its only argument and returns the output
    /// array. Panics if the call fails or yields a scalar.
    fn run_hex(input: ArrayRef) -> ArrayRef {
        match super::spark_hex(&[ColumnarValue::Array(input)]).unwrap() {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
        builder.append_value("hi");
        builder.append_value("bye");
        builder.append_null();
        builder.append_value("rust");

        let result = run_hex(Arc::new(builder.finish()));

        let expected =
            StringArray::from(vec![Some("6869"), Some("627965"), None, Some("72757374")]);
        assert_eq!(as_string_array(&result), &expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        builder.append_value(1);
        builder.append_value(2);
        builder.append_null();
        builder.append_value(3);

        let result = run_hex(Arc::new(builder.finish()));

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);
        assert_eq!(as_string_array(&result), &expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut builder = BinaryDictionaryBuilder::<Int32Type>::new();
        builder.append_value("1");
        builder.append_value("j");
        builder.append_null();
        builder.append_value("3");

        let result = run_hex(Arc::new(builder.finish()));

        let expected = StringArray::from(vec![Some("31"), Some("6A"), None, Some("33")]);
        assert_eq!(as_string_array(&result), &expected);
    }

    #[test]
    fn test_hex_int64() {
        assert_eq!(super::hex_int64(1234), "4D2".to_string());
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let result = run_hex(Arc::new(input));

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);
        assert_eq!(as_string_array(&result), &expected);
    }
}