// datafusion_spark/function/math/hex.rs
use std::any::Any;
use std::sync::Arc;

use crate::function::error_utils::{
    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
};
use arrow::array::{Array, ArrayRef, AsArray, StringArray};
use arrow::compute::take;
use arrow::datatypes::DataType;
use arrow::{
    array::{as_dictionary_array, as_largestring_array, as_string_array},
    datatypes::Int32Type,
};
use datafusion_common::cast::as_string_view_array;
use datafusion_common::{
    cast::{
        as_binary_array, as_fixed_size_binary_array, as_int64_array,
        as_large_binary_array,
    },
    exec_err, DataFusionError,
};
use datafusion_expr::Signature;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
use std::fmt::Write;
38
39#[derive(Debug, PartialEq, Eq, Hash)]
41pub struct SparkHex {
42 signature: Signature,
43 aliases: Vec<String>,
44}
45
46impl Default for SparkHex {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl SparkHex {
53 pub fn new() -> Self {
54 Self {
55 signature: Signature::user_defined(Volatility::Immutable),
56 aliases: vec![],
57 }
58 }
59}
60
61impl ScalarUDFImpl for SparkHex {
62 fn as_any(&self) -> &dyn Any {
63 self
64 }
65
66 fn name(&self) -> &str {
67 "hex"
68 }
69
70 fn signature(&self) -> &Signature {
71 &self.signature
72 }
73
74 fn return_type(
75 &self,
76 _arg_types: &[DataType],
77 ) -> datafusion_common::Result<DataType> {
78 Ok(DataType::Utf8)
79 }
80
81 fn invoke_with_args(
82 &self,
83 args: ScalarFunctionArgs,
84 ) -> datafusion_common::Result<ColumnarValue> {
85 spark_hex(&args.args)
86 }
87
88 fn aliases(&self) -> &[String] {
89 &self.aliases
90 }
91
92 fn coerce_types(
93 &self,
94 arg_types: &[DataType],
95 ) -> datafusion_common::Result<Vec<DataType>> {
96 if arg_types.len() != 1 {
97 return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len()));
98 }
99 match &arg_types[0] {
100 DataType::Int64
101 | DataType::Utf8
102 | DataType::Utf8View
103 | DataType::LargeUtf8
104 | DataType::Binary
105 | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
106 DataType::Dictionary(key_type, value_type) => match value_type.as_ref() {
107 DataType::Int64
108 | DataType::Utf8
109 | DataType::Utf8View
110 | DataType::LargeUtf8
111 | DataType::Binary
112 | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]),
113 other => {
114 if other.is_numeric() {
115 Ok(vec![DataType::Dictionary(
116 key_type.clone(),
117 Box::new(DataType::Int64),
118 )])
119 } else {
120 Err(unsupported_data_type_exec_err(
121 "hex",
122 "Numeric, String, or Binary",
123 &arg_types[0],
124 ))
125 }
126 }
127 },
128 other => {
129 if other.is_numeric() {
130 Ok(vec![DataType::Int64])
131 } else {
132 Err(unsupported_data_type_exec_err(
133 "hex",
134 "Numeric, String, or Binary",
135 &arg_types[0],
136 ))
137 }
138 }
139 }
140 }
141}
142
/// Formats `num` as upper-case hexadecimal with no leading zeros.
///
/// Negative numbers are shown as their 64-bit two's-complement bit pattern,
/// e.g. `-1` becomes `FFFFFFFFFFFFFFFF`.
fn hex_int64(num: i64) -> String {
    // Reinterpreting the bits as unsigned makes the two's-complement output
    // explicit; `{:X}` on the signed value prints the same digits.
    let bits = num as u64;
    format!("{bits:X}")
}
146
/// Encodes `data` as hexadecimal, two digits per byte (most significant
/// nibble first).
///
/// `lower_case` selects `a`-`f` instead of `A`-`F` for digits above 9. Empty
/// input yields an empty string.
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    const UPPER_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    const LOWER_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let digits = if lower_case { LOWER_DIGITS } else { UPPER_DIGITS };
    let bytes = data.as_ref();
    let mut s = String::with_capacity(bytes.len() * 2);
    // A per-nibble table lookup avoids invoking the `fmt` machinery (and an
    // `unwrap`) for every byte.
    for &b in bytes {
        s.push(char::from(digits[usize::from(b >> 4)]));
        s.push(char::from(digits[usize::from(b & 0x0F)]));
    }
    s
}
163
164#[inline(always)]
165fn hex_bytes<T: AsRef<[u8]>>(
166 bytes: T,
167 lowercase: bool,
168) -> Result<String, std::fmt::Error> {
169 let hex_string = hex_encode(bytes, lowercase);
170 Ok(hex_string)
171}
172
173pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
175 compute_hex(args, false)
176}
177
178pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
180 compute_hex(args, true)
181}
182
183pub fn compute_hex(
184 args: &[ColumnarValue],
185 lowercase: bool,
186) -> Result<ColumnarValue, DataFusionError> {
187 if args.len() != 1 {
188 return Err(DataFusionError::Internal(
189 "hex expects exactly one argument".to_string(),
190 ));
191 }
192
193 let input = match &args[0] {
194 ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?),
195 ColumnarValue::Array(_) => args[0].clone(),
196 };
197
198 match &input {
199 ColumnarValue::Array(array) => match array.data_type() {
200 DataType::Int64 => {
201 let array = as_int64_array(array)?;
202
203 let hexed_array: StringArray =
204 array.iter().map(|v| v.map(hex_int64)).collect();
205
206 Ok(ColumnarValue::Array(Arc::new(hexed_array)))
207 }
208 DataType::Utf8 => {
209 let array = as_string_array(array);
210
211 let hexed: StringArray = array
212 .iter()
213 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
214 .collect::<Result<_, _>>()?;
215
216 Ok(ColumnarValue::Array(Arc::new(hexed)))
217 }
218 DataType::Utf8View => {
219 let array = as_string_view_array(array)?;
220
221 let hexed: StringArray = array
222 .iter()
223 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
224 .collect::<Result<_, _>>()?;
225
226 Ok(ColumnarValue::Array(Arc::new(hexed)))
227 }
228 DataType::LargeUtf8 => {
229 let array = as_largestring_array(array);
230
231 let hexed: StringArray = array
232 .iter()
233 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
234 .collect::<Result<_, _>>()?;
235
236 Ok(ColumnarValue::Array(Arc::new(hexed)))
237 }
238 DataType::Binary => {
239 let array = as_binary_array(array)?;
240
241 let hexed: StringArray = array
242 .iter()
243 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
244 .collect::<Result<_, _>>()?;
245
246 Ok(ColumnarValue::Array(Arc::new(hexed)))
247 }
248 DataType::FixedSizeBinary(_) => {
249 let array = as_fixed_size_binary_array(array)?;
250
251 let hexed: StringArray = array
252 .iter()
253 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
254 .collect::<Result<_, _>>()?;
255
256 Ok(ColumnarValue::Array(Arc::new(hexed)))
257 }
258 DataType::Dictionary(_, value_type) => {
259 let dict = as_dictionary_array::<Int32Type>(&array);
260
261 let values = match **value_type {
262 DataType::Int64 => as_int64_array(dict.values())?
263 .iter()
264 .map(|v| v.map(hex_int64))
265 .collect::<Vec<_>>(),
266 DataType::Utf8 => as_string_array(dict.values())
267 .iter()
268 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
269 .collect::<Result<_, _>>()?,
270 DataType::Binary => as_binary_array(dict.values())?
271 .iter()
272 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
273 .collect::<Result<_, _>>()?,
274 _ => exec_err!(
275 "hex got an unexpected argument type: {:?}",
276 array.data_type()
277 )?,
278 };
279
280 let new_values: Vec<Option<String>> = dict
281 .keys()
282 .iter()
283 .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
284 .collect();
285
286 let string_array_values = StringArray::from(new_values);
287
288 Ok(ColumnarValue::Array(Arc::new(string_array_values)))
289 }
290 _ => exec_err!(
291 "hex got an unexpected argument type: {:?}",
292 array.data_type()
293 ),
294 },
295 _ => exec_err!("native hex does not support scalar values at this time"),
296 }
297}
298
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{
        as_string_array, BinaryDictionaryBuilder, Int64Array, PrimitiveDictionaryBuilder,
        StringArray, StringDictionaryBuilder,
    };
    use arrow::datatypes::{Int32Type, Int64Type};
    use datafusion_expr::ColumnarValue;

    /// Applies `spark_hex` to `input` and downcasts the result to a
    /// `StringArray`. Panics if the call fails or does not return an array.
    fn hex_strings(input: ColumnarValue) -> StringArray {
        match super::spark_hex(&[input]).unwrap() {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut builder = StringDictionaryBuilder::<Int32Type>::new();
        builder.append_value("hi");
        builder.append_value("bye");
        builder.append_null();
        builder.append_value("rust");
        let input = ColumnarValue::Array(Arc::new(builder.finish()));

        let expected =
            StringArray::from(vec![Some("6869"), Some("627965"), None, Some("72757374")]);
        assert_eq!(hex_strings(input), expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        builder.append_value(1);
        builder.append_value(2);
        builder.append_null();
        builder.append_value(3);
        let input = ColumnarValue::Array(Arc::new(builder.finish()));

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);
        assert_eq!(hex_strings(input), expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut builder = BinaryDictionaryBuilder::<Int32Type>::new();
        builder.append_value("1");
        builder.append_value("j");
        builder.append_null();
        builder.append_value("3");
        let input = ColumnarValue::Array(Arc::new(builder.finish()));

        let expected = StringArray::from(vec![Some("31"), Some("6A"), None, Some("33")]);
        assert_eq!(hex_strings(input), expected);
    }

    #[test]
    fn test_hex_int64() {
        assert_eq!(super::hex_int64(1234), "4D2".to_string());
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let ints = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
        let input = ColumnarValue::Array(Arc::new(ints));

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);
        assert_eq!(hex_strings(input), expected);
    }
}