1use std::str::from_utf8_unchecked;
19use std::sync::Arc;
20
21use arrow::array::{Array, ArrayRef, StringBuilder};
22use arrow::datatypes::DataType;
23use arrow::{
24 array::{as_dictionary_array, as_largestring_array, as_string_array},
25 datatypes::Int32Type,
26};
27use datafusion_common::cast::as_large_binary_array;
28use datafusion_common::cast::as_string_view_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{
32 DataFusionError,
33 cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
34 exec_err,
35};
36use datafusion_expr::{
37 Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
38 TypeSignatureClass, Volatility,
39};
40#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkHex {
43 signature: Signature,
44 aliases: Vec<String>,
45}
46
47impl Default for SparkHex {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SparkHex {
54 pub fn new() -> Self {
55 let int64 = Coercion::new_implicit(
56 TypeSignatureClass::Native(logical_int64()),
57 vec![TypeSignatureClass::Numeric],
58 NativeType::Int64,
59 );
60
61 let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
62
63 let binary = Coercion::new_exact(TypeSignatureClass::Binary);
64
65 let variants = vec![
66 TypeSignature::Coercible(vec![int64]),
68 TypeSignature::Coercible(vec![string]),
70 TypeSignature::Coercible(vec![binary]),
72 ];
73
74 Self {
75 signature: Signature::one_of(variants, Volatility::Immutable),
76 aliases: vec![],
77 }
78 }
79}
80
81impl ScalarUDFImpl for SparkHex {
82 fn name(&self) -> &str {
83 "hex"
84 }
85
86 fn signature(&self) -> &Signature {
87 &self.signature
88 }
89
90 fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
91 Ok(match &arg_types[0] {
92 DataType::Dictionary(key_type, _) => {
93 DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
94 }
95 _ => DataType::Utf8,
96 })
97 }
98
99 fn invoke_with_args(
100 &self,
101 args: ScalarFunctionArgs,
102 ) -> datafusion_common::Result<ColumnarValue> {
103 spark_hex(&args.args)
104 }
105
106 fn aliases(&self) -> &[String] {
107 &self.aliases
108 }
109}
110
/// ASCII digits for a single nibble, uppercase alphabet.
const HEX_CHARS_UPPER_NIBBLES: &[u8; 16] = b"0123456789ABCDEF";
/// ASCII digits for a single nibble, lowercase alphabet.
const HEX_CHARS_LOWER_NIBBLES: &[u8; 16] = b"0123456789abcdef";

/// Byte value -> its two uppercase hex digits (high nibble first).
const HEX_LOOKUP_UPPER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_UPPER_NIBBLES);
/// Byte value -> its two lowercase hex digits (high nibble first).
const HEX_LOOKUP_LOWER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_LOWER_NIBBLES);

/// Builds, at compile time, a 256-entry table mapping every byte value to
/// the pair of digits taken from `nibbles` for its high and low nibble.
const fn build_hex_lookup(nibbles: &[u8; 16]) -> [[u8; 2]; 256] {
    let mut table = [[0u8; 2]; 256];
    // `for` loops are not available in `const fn`, hence the manual counters.
    let mut hi = 0;
    while hi < 16 {
        let mut lo = 0;
        while lo < 16 {
            table[(hi << 4) | lo] = [nibbles[hi], nibbles[lo]];
            lo += 1;
        }
        hi += 1;
    }
    table
}
132
133#[inline]
134fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] {
135 if num == 0 {
136 return b"0";
137 }
138
139 let mut n = num as u64;
143 let mut i = 16;
144 while n >= 0x10 {
145 i -= 2;
146 let pair = HEX_LOOKUP_UPPER[(n & 0xFF) as usize];
147 buffer[i] = pair[0];
148 buffer[i + 1] = pair[1];
149 n >>= 8;
150 }
151 if n > 0 {
152 i -= 1;
154 buffer[i] = HEX_CHARS_UPPER_NIBBLES[n as usize];
155 }
156 &buffer[i..]
157}
158
159fn hex_encode_bytes<'a, I, T>(
161 iter: I,
162 lowercase: bool,
163 len: usize,
164) -> Result<ArrayRef, DataFusionError>
165where
166 I: Iterator<Item = Option<T>>,
167 T: AsRef<[u8]> + 'a,
168{
169 let mut builder = StringBuilder::with_capacity(len, len * 64);
170 let mut buffer = Vec::with_capacity(64);
171 let lookup = if lowercase {
172 &HEX_LOOKUP_LOWER
173 } else {
174 &HEX_LOOKUP_UPPER
175 };
176
177 for v in iter {
178 if let Some(b) = v {
179 let bytes = b.as_ref();
180 buffer.clear();
181 buffer.reserve(bytes.len() * 2);
182 for &byte in bytes {
183 buffer.extend_from_slice(&lookup[byte as usize]);
184 }
185 unsafe {
187 builder.append_value(from_utf8_unchecked(&buffer));
188 }
189 } else {
190 builder.append_null();
191 }
192 }
193
194 Ok(Arc::new(builder.finish()))
195}
196
197fn hex_encode_int64(
199 iter: impl Iterator<Item = Option<i64>>,
200 len: usize,
201) -> Result<ArrayRef, DataFusionError> {
202 let mut builder = StringBuilder::with_capacity(len, len * 16);
203
204 for v in iter {
205 if let Some(num) = v {
206 let mut temp = [0u8; 16];
207 let slice = hex_int64(num, &mut temp);
208 unsafe {
210 builder.append_value(from_utf8_unchecked(slice));
211 }
212 } else {
213 builder.append_null();
214 }
215 }
216
217 Ok(Arc::new(builder.finish()))
218}
219
220pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
222 compute_hex(args, false)
223}
224
225pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
227 compute_hex(args, true)
228}
229
230pub fn compute_hex(
231 args: &[ColumnarValue],
232 lowercase: bool,
233) -> Result<ColumnarValue, DataFusionError> {
234 let input = match take_function_args("hex", args)? {
235 [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
236 [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
237 };
238
239 match &input {
240 ColumnarValue::Array(array) => match array.data_type() {
241 DataType::Int64 => {
242 let array = as_int64_array(array)?;
243 Ok(ColumnarValue::Array(hex_encode_int64(
244 array.iter(),
245 array.len(),
246 )?))
247 }
248 DataType::Utf8 => {
249 let array = as_string_array(array);
250 Ok(ColumnarValue::Array(hex_encode_bytes(
251 array.iter(),
252 lowercase,
253 array.len(),
254 )?))
255 }
256 DataType::Utf8View => {
257 let array = as_string_view_array(array)?;
258 Ok(ColumnarValue::Array(hex_encode_bytes(
259 array.iter(),
260 lowercase,
261 array.len(),
262 )?))
263 }
264 DataType::LargeUtf8 => {
265 let array = as_largestring_array(array);
266 Ok(ColumnarValue::Array(hex_encode_bytes(
267 array.iter(),
268 lowercase,
269 array.len(),
270 )?))
271 }
272 DataType::Binary => {
273 let array = as_binary_array(array)?;
274 Ok(ColumnarValue::Array(hex_encode_bytes(
275 array.iter(),
276 lowercase,
277 array.len(),
278 )?))
279 }
280 DataType::LargeBinary => {
281 let array = as_large_binary_array(array)?;
282 Ok(ColumnarValue::Array(hex_encode_bytes(
283 array.iter(),
284 lowercase,
285 array.len(),
286 )?))
287 }
288 DataType::FixedSizeBinary(_) => {
289 let array = as_fixed_size_binary_array(array)?;
290 Ok(ColumnarValue::Array(hex_encode_bytes(
291 array.iter(),
292 lowercase,
293 array.len(),
294 )?))
295 }
296 DataType::Dictionary(key_type, _) => {
297 if **key_type != DataType::Int32 {
298 return exec_err!(
299 "hex only supports Int32 dictionary keys, get: {}",
300 key_type
301 );
302 }
303
304 let dict = as_dictionary_array::<Int32Type>(&array);
305 let dict_values = dict.values();
306
307 let encoded_values = match dict_values.data_type() {
308 DataType::Int64 => {
309 let arr = as_int64_array(dict_values)?;
310 hex_encode_int64(arr.iter(), arr.len())?
311 }
312 DataType::Utf8 => {
313 let arr = as_string_array(dict_values);
314 hex_encode_bytes(arr.iter(), lowercase, arr.len())?
315 }
316 DataType::LargeUtf8 => {
317 let arr = as_largestring_array(dict_values);
318 hex_encode_bytes(arr.iter(), lowercase, arr.len())?
319 }
320 DataType::Utf8View => {
321 let arr = as_string_view_array(dict_values)?;
322 hex_encode_bytes(arr.iter(), lowercase, arr.len())?
323 }
324 DataType::Binary => {
325 let arr = as_binary_array(dict_values)?;
326 hex_encode_bytes(arr.iter(), lowercase, arr.len())?
327 }
328 DataType::LargeBinary => {
329 let arr = as_large_binary_array(dict_values)?;
330 hex_encode_bytes(arr.iter(), lowercase, arr.len())?
331 }
332 DataType::FixedSizeBinary(_) => {
333 let arr = as_fixed_size_binary_array(dict_values)?;
334 hex_encode_bytes(arr.iter(), lowercase, arr.len())?
335 }
336 _ => {
337 return exec_err!(
338 "hex got an unexpected argument type: {}",
339 dict_values.data_type()
340 );
341 }
342 };
343
344 let new_dict = dict.with_values(encoded_values);
345 Ok(ColumnarValue::Array(Arc::new(new_dict)))
346 }
347 _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
348 },
349 _ => exec_err!("native hex does not support scalar values at this time"),
350 }
351}
352
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, BinaryArray, DictionaryArray, Int32Array, Int64Array, StringArray,
    };
    use arrow::{
        array::{
            BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
            as_string_array,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_common::cast::as_dictionary_array;
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on a single array argument and unwraps the array
    /// result, panicking if the call fails or does not return an array.
    fn hex_array(input: ArrayRef) -> ArrayRef {
        match super::spark_hex(&[ColumnarValue::Array(input)]).unwrap() {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("hi");
        input_builder.append_value("bye");
        input_builder.append_null();
        input_builder.append_value("rust");
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("6869");
        expected_builder.append_value("627965");
        expected_builder.append_null();
        expected_builder.append_value("72757374");
        let expected = expected_builder.finish();

        let result = hex_array(Arc::new(input));
        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        input_builder.append_value(1);
        input_builder.append_value(2);
        input_builder.append_null();
        input_builder.append_value(3);
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("1");
        expected_builder.append_value("2");
        expected_builder.append_null();
        expected_builder.append_value("3");
        let expected = expected_builder.finish();

        let result = hex_array(Arc::new(input));
        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        input_builder.append_value("1");
        input_builder.append_value("j");
        input_builder.append_null();
        input_builder.append_value("3");
        let input = input_builder.finish();

        let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
        expected_builder.append_value("31");
        expected_builder.append_value("6A");
        expected_builder.append_null();
        expected_builder.append_value("33");
        let expected = expected_builder.finish();

        let result = hex_array(Arc::new(input));
        let result = as_dictionary_array(&result).unwrap();

        assert_eq!(result, &expected);
    }

    #[test]
    fn test_hex_int64() {
        let test_cases = vec![
            (0_i64, "0"),
            (1, "1"),
            (15, "F"),
            (16, "10"),
            (255, "FF"),
            (256, "100"),
            (1234, "4D2"),
            (i64::MAX, "7FFFFFFFFFFFFFFF"),
            (i64::MIN, "8000000000000000"),
            (-1, "FFFFFFFFFFFFFFFF"),
        ];

        for (num, expected) in test_cases {
            let mut cache = [0u8; 16];
            let slice = super::hex_int64(num, &mut cache);

            // Checked conversion: a non-ASCII byte from the encoder must fail
            // the test instead of being undefined behaviour.
            let result = std::str::from_utf8(slice)
                .expect("hex_int64 must produce valid UTF-8");
            assert_eq!(expected, result, "hex_int64({num}) mismatch");
        }
    }

    #[test]
    fn test_hex_lookup_table_covers_all_bytes() {
        for byte in 0u8..=255 {
            let upper = format!("{byte:02X}");
            let lower = format!("{byte:02x}");
            let upper_pair = super::HEX_LOOKUP_UPPER[byte as usize];
            let lower_pair = super::HEX_LOOKUP_LOWER[byte as usize];
            assert_eq!(
                upper.as_bytes(),
                &upper_pair,
                "upper encoding mismatch for byte 0x{byte:02X}"
            );
            assert_eq!(
                lower.as_bytes(),
                &lower_pair,
                "lower encoding mismatch for byte 0x{byte:02X}"
            );
        }
    }

    #[test]
    fn test_spark_hex_binary_round_trip_all_bytes() {
        let payload: Vec<u8> = (0u8..=255).collect();
        let bin_array = BinaryArray::from(vec![Some(payload.as_slice())]);

        let array = hex_array(Arc::new(bin_array));
        let strings = as_string_array(&array);

        let mut expected = String::with_capacity(512);
        for byte in 0u8..=255 {
            use std::fmt::Write;
            write!(expected, "{byte:02X}").unwrap();
        }
        assert_eq!(strings.value(0), expected);
    }

    #[test]
    fn test_spark_sha2_hex_is_lowercase() {
        let payload = [0xAB_u8, 0xCD, 0x0F];
        let bin_array = BinaryArray::from(vec![Some(payload.as_slice()), None]);

        let result =
            super::spark_sha2_hex(&[ColumnarValue::Array(Arc::new(bin_array))]).unwrap();
        let array = match result {
            ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let expected = StringArray::from(vec![Some("abcd0f"), None]);
        assert_eq!(as_string_array(&array), &expected);
    }

    #[test]
    fn test_spark_hex_int64() {
        let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let result = hex_array(Arc::new(int_array));

        let string_array = as_string_array(&result);
        let expected_array = StringArray::from(vec![
            Some("1".to_string()),
            Some("2".to_string()),
            None,
            Some("3".to_string()),
        ]);

        assert_eq!(string_array, &expected_array);
    }

    #[test]
    fn test_dict_values_null() {
        // Both a null key and a null dictionary value must survive encoding.
        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let vals = Int64Array::from(vec![Some(32), None]);
        let dict = DictionaryArray::new(keys, Arc::new(vals));

        let result = hex_array(Arc::new(dict));
        let result = as_dictionary_array(&result).unwrap();

        let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
        let vals = StringArray::from(vec![Some("20"), None]);
        let expected = DictionaryArray::new(keys, Arc::new(vals));

        assert_eq!(&expected, result);
    }
}