datafusion_spark/function/math/
hex.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::function::error_utils::{
22 invalid_arg_count_exec_err, unsupported_data_type_exec_err,
23};
24use arrow::array::{Array, StringArray};
25use arrow::datatypes::DataType;
26use arrow::{
27 array::{as_dictionary_array, as_largestring_array, as_string_array},
28 datatypes::Int32Type,
29};
30use datafusion_common::{
31 cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
32 exec_err, DataFusionError,
33};
34use datafusion_expr::Signature;
35use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
36use std::fmt::Write;
37
38#[derive(Debug)]
40pub struct SparkHex {
41 signature: Signature,
42 aliases: Vec<String>,
43}
44
45impl Default for SparkHex {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SparkHex {
52 pub fn new() -> Self {
53 Self {
54 signature: Signature::user_defined(Volatility::Immutable),
55 aliases: vec![],
56 }
57 }
58}
59
60impl ScalarUDFImpl for SparkHex {
61 fn as_any(&self) -> &dyn Any {
62 self
63 }
64
65 fn name(&self) -> &str {
66 "hex"
67 }
68
69 fn signature(&self) -> &Signature {
70 &self.signature
71 }
72
73 fn return_type(
74 &self,
75 _arg_types: &[DataType],
76 ) -> datafusion_common::Result<DataType> {
77 Ok(DataType::Utf8)
78 }
79
80 fn invoke_with_args(
81 &self,
82 args: ScalarFunctionArgs,
83 ) -> datafusion_common::Result<ColumnarValue> {
84 spark_hex(&args.args)
85 }
86
87 fn aliases(&self) -> &[String] {
88 &self.aliases
89 }
90
91 fn coerce_types(
92 &self,
93 arg_types: &[DataType],
94 ) -> datafusion_common::Result<Vec<DataType>> {
95 if arg_types.len() != 1 {
96 return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len()));
97 }
98 match &arg_types[0] {
99 DataType::Int64
100 | DataType::Utf8
101 | DataType::LargeUtf8
102 | DataType::Binary
103 | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
104 DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
105 DataType::Int64
106 | DataType::Utf8
107 | DataType::LargeUtf8
108 | DataType::Binary
109 | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
110 other => {
111 if other.is_numeric() {
112 Ok(vec![DataType::Dictionary(
113 key_type.clone(),
114 Box::new(DataType::Int64),
115 )])
116 } else {
117 Err(unsupported_data_type_exec_err(
118 "hex",
119 "Numeric, String, or Binary",
120 &arg_types[0],
121 ))
122 }
123 }
124 },
125 other => {
126 if other.is_numeric() {
127 Ok(vec![DataType::Int64])
128 } else {
129 Err(unsupported_data_type_exec_err(
130 "hex",
131 "Numeric, String, or Binary",
132 &arg_types[0],
133 ))
134 }
135 }
136 }
137 }
138}
139
/// Formats `num` as upper-case hexadecimal without a prefix or leading zeros.
///
/// Negative numbers are rendered as their 64-bit two's-complement bit
/// pattern, e.g. `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    // Same-width reinterpretation: no truncation, only the sign is dropped so
    // the two's-complement bits are printed as-is.
    let bits = num as u64;
    format!("{bits:X}")
}
143
/// Encodes `data` as a hex string, two digits per input byte.
///
/// Digits `a`-`f` are lower-case when `lower_case` is `true` and upper-case
/// otherwise. An empty input yields an empty string.
#[inline(always)]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let bytes = data.as_ref();
    let mut encoded = String::with_capacity(bytes.len() * 2);

    // The case is chosen once, outside the per-byte loop.
    let written = if lower_case {
        bytes
            .iter()
            .try_for_each(|byte| write!(encoded, "{byte:02x}"))
    } else {
        bytes
            .iter()
            .try_for_each(|byte| write!(encoded, "{byte:02X}"))
    };
    written.expect("formatting into a String never fails");

    encoded
}
160
161#[inline(always)]
162fn hex_bytes<T: AsRef<[u8]>>(
163 bytes: T,
164 lowercase: bool,
165) -> Result<String, std::fmt::Error> {
166 let hex_string = hex_encode(bytes, lowercase);
167 Ok(hex_string)
168}
169
170pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
172 compute_hex(args, false)
173}
174
175pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
177 compute_hex(args, true)
178}
179
180pub fn compute_hex(
181 args: &[ColumnarValue],
182 lowercase: bool,
183) -> Result<ColumnarValue, DataFusionError> {
184 if args.len() != 1 {
185 return Err(DataFusionError::Internal(
186 "hex expects exactly one argument".to_string(),
187 ));
188 }
189
190 let input = match &args[0] {
191 ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
192 ColumnarValue::Array(_) => args[0].clone(),
193 };
194
195 match &input {
196 ColumnarValue::Array(array) => match array.data_type() {
197 DataType::Int64 => {
198 let array = as_int64_array(array)?;
199
200 let hexed_array: StringArray =
201 array.iter().map(|v| v.map(hex_int64)).collect();
202
203 Ok(ColumnarValue::Array(Arc::new(hexed_array)))
204 }
205 DataType::Utf8 => {
206 let array = as_string_array(array);
207
208 let hexed: StringArray = array
209 .iter()
210 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
211 .collect::<Result<_, _>>()?;
212
213 Ok(ColumnarValue::Array(Arc::new(hexed)))
214 }
215 DataType::LargeUtf8 => {
216 let array = as_largestring_array(array);
217
218 let hexed: StringArray = array
219 .iter()
220 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
221 .collect::<Result<_, _>>()?;
222
223 Ok(ColumnarValue::Array(Arc::new(hexed)))
224 }
225 DataType::Binary => {
226 let array = as_binary_array(array)?;
227
228 let hexed: StringArray = array
229 .iter()
230 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
231 .collect::<Result<_, _>>()?;
232
233 Ok(ColumnarValue::Array(Arc::new(hexed)))
234 }
235 DataType::FixedSizeBinary(_) => {
236 let array = as_fixed_size_binary_array(array)?;
237
238 let hexed: StringArray = array
239 .iter()
240 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
241 .collect::<Result<_, _>>()?;
242
243 Ok(ColumnarValue::Array(Arc::new(hexed)))
244 }
245 DataType::Dictionary(_, value_type) => {
246 let dict = as_dictionary_array::<Int32Type>(&array);
247
248 let values = match **value_type {
249 DataType::Int64 => as_int64_array(dict.values())?
250 .iter()
251 .map(|v| v.map(hex_int64))
252 .collect::<Vec<_>>(),
253 DataType::Utf8 => as_string_array(dict.values())
254 .iter()
255 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
256 .collect::<Result<_, _>>()?,
257 DataType::Binary => as_binary_array(dict.values())?
258 .iter()
259 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
260 .collect::<Result<_, _>>()?,
261 _ => exec_err!(
262 "hex got an unexpected argument type: {:?}",
263 array.data_type()
264 )?,
265 };
266
267 let new_values: Vec<Option<String>> = dict
268 .keys()
269 .iter()
270 .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
271 .collect();
272
273 let string_array_values = StringArray::from(new_values);
274
275 Ok(ColumnarValue::Array(Arc::new(string_array_values)))
276 }
277 _ => exec_err!(
278 "hex got an unexpected argument type: {:?}",
279 array.data_type()
280 ),
281 },
282 _ => exec_err!("native hex does not support scalar values at this time"),
283 }
284}
285
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::{
        array::{
            as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder,
            StringDictionaryBuilder,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on a single argument and returns the resulting
    /// string array, panicking when the call fails or yields a scalar.
    fn hex_strings(input: ColumnarValue) -> StringArray {
        match super::spark_hex(&[input]).unwrap() {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("hi");
        input_builder.append_value("bye");
        input_builder.append_null();
        input_builder.append_value("rust");
        let input = ColumnarValue::Array(Arc::new(input_builder.finish()));

        let expected =
            StringArray::from(vec![Some("6869"), Some("627965"), None, Some("72757374")]);

        assert_eq!(hex_strings(input), expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        input_builder.append_value(1);
        input_builder.append_value(2);
        input_builder.append_null();
        input_builder.append_value(3);
        let input = ColumnarValue::Array(Arc::new(input_builder.finish()));

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(hex_strings(input), expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("1");
        input_builder.append_value("j");
        input_builder.append_null();
        input_builder.append_value("3");
        let input = ColumnarValue::Array(Arc::new(input_builder.finish()));

        let expected = StringArray::from(vec![Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(hex_strings(input), expected);
    }

    #[test]
    fn test_hex_int64() {
        // Positive values are printed without leading zeros.
        assert_eq!(super::hex_int64(1234), "4D2".to_string());

        // Negative values are printed as their two's-complement bit pattern.
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
        let input = ColumnarValue::Array(Arc::new(int_array));

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(hex_strings(input), expected);
    }
}