datafusion_spark/function/math/
hex.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, StringArray};
22use arrow::datatypes::DataType;
23use arrow::{
24 array::{as_dictionary_array, as_largestring_array, as_string_array},
25 datatypes::Int32Type,
26};
27use datafusion_common::cast::as_large_binary_array;
28use datafusion_common::cast::as_string_view_array;
29use datafusion_common::types::{NativeType, logical_int64, logical_string};
30use datafusion_common::utils::take_function_args;
31use datafusion_common::{
32 DataFusionError,
33 cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
34 exec_err,
35};
36use datafusion_expr::{
37 Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
38 TypeSignatureClass, Volatility,
39};
40#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkHex {
43 signature: Signature,
44 aliases: Vec<String>,
45}
46
47impl Default for SparkHex {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SparkHex {
54 pub fn new() -> Self {
55 let int64 = Coercion::new_implicit(
56 TypeSignatureClass::Native(logical_int64()),
57 vec![TypeSignatureClass::Numeric],
58 NativeType::Int64,
59 );
60
61 let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
62
63 let binary = Coercion::new_exact(TypeSignatureClass::Binary);
64
65 let variants = vec![
66 TypeSignature::Coercible(vec![int64]),
68 TypeSignature::Coercible(vec![string]),
70 TypeSignature::Coercible(vec![binary]),
72 ];
73
74 Self {
75 signature: Signature::one_of(variants, Volatility::Immutable),
76 aliases: vec![],
77 }
78 }
79}
80
81impl ScalarUDFImpl for SparkHex {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "hex"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(
95 &self,
96 _arg_types: &[DataType],
97 ) -> datafusion_common::Result<DataType> {
98 Ok(DataType::Utf8)
99 }
100
101 fn invoke_with_args(
102 &self,
103 args: ScalarFunctionArgs,
104 ) -> datafusion_common::Result<ColumnarValue> {
105 spark_hex(&args.args)
106 }
107
108 fn aliases(&self) -> &[String] {
109 &self.aliases
110 }
111}
112
/// Formats `num` as upper-case hexadecimal with no prefix and no zero padding.
///
/// Negative values are rendered as their 64-bit two's-complement bit pattern
/// (e.g. `-1` becomes `FFFFFFFFFFFFFFFF`).
fn hex_int64(num: i64) -> String {
    // Same-width signed -> unsigned cast keeps every bit, which is exactly the
    // pattern the signed `{:X}` formatter prints.
    let bits = num as u64;
    format!("{bits:X}")
}
116
/// Lower-case hexadecimal digits, indexed by nibble value.
const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef";
/// Upper-case hexadecimal digits, indexed by nibble value.
const HEX_CHARS_UPPER: &[u8; 16] = b"0123456789ABCDEF";

/// Encodes `data` as a hexadecimal string, two digits per input byte
/// (high nibble first).
///
/// `lower_case` selects `a-f` digits; otherwise `A-F` digits are used.
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
    let digits = if lower_case {
        HEX_CHARS_LOWER
    } else {
        HEX_CHARS_UPPER
    };

    let input = data.as_ref();
    // Output length is known up front: exactly two characters per byte.
    let mut encoded = String::with_capacity(input.len() * 2);
    encoded.extend(input.iter().flat_map(|&byte| {
        [
            digits[usize::from(byte >> 4)] as char,
            digits[usize::from(byte & 0x0f)] as char,
        ]
    }));
    encoded
}
136
137#[inline(always)]
138fn hex_bytes<T: AsRef<[u8]>>(
139 bytes: T,
140 lowercase: bool,
141) -> Result<String, std::fmt::Error> {
142 let hex_string = hex_encode(bytes, lowercase);
143 Ok(hex_string)
144}
145
146pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
148 compute_hex(args, false)
149}
150
151pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
153 compute_hex(args, true)
154}
155
156pub fn compute_hex(
157 args: &[ColumnarValue],
158 lowercase: bool,
159) -> Result<ColumnarValue, DataFusionError> {
160 let input = match take_function_args("hex", args)? {
161 [ColumnarValue::Scalar(value)] => ColumnarValue::Array(value.to_array()?),
162 [ColumnarValue::Array(arr)] => ColumnarValue::Array(Arc::clone(arr)),
163 };
164
165 match &input {
166 ColumnarValue::Array(array) => match array.data_type() {
167 DataType::Int64 => {
168 let array = as_int64_array(array)?;
169
170 let hexed_array: StringArray =
171 array.iter().map(|v| v.map(hex_int64)).collect();
172
173 Ok(ColumnarValue::Array(Arc::new(hexed_array)))
174 }
175 DataType::Utf8 => {
176 let array = as_string_array(array);
177
178 let hexed: StringArray = array
179 .iter()
180 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
181 .collect::<Result<_, _>>()?;
182
183 Ok(ColumnarValue::Array(Arc::new(hexed)))
184 }
185 DataType::Utf8View => {
186 let array = as_string_view_array(array)?;
187
188 let hexed: StringArray = array
189 .iter()
190 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
191 .collect::<Result<_, _>>()?;
192
193 Ok(ColumnarValue::Array(Arc::new(hexed)))
194 }
195 DataType::LargeUtf8 => {
196 let array = as_largestring_array(array);
197
198 let hexed: StringArray = array
199 .iter()
200 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
201 .collect::<Result<_, _>>()?;
202
203 Ok(ColumnarValue::Array(Arc::new(hexed)))
204 }
205 DataType::Binary => {
206 let array = as_binary_array(array)?;
207
208 let hexed: StringArray = array
209 .iter()
210 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
211 .collect::<Result<_, _>>()?;
212
213 Ok(ColumnarValue::Array(Arc::new(hexed)))
214 }
215 DataType::LargeBinary => {
216 let array = as_large_binary_array(array)?;
217
218 let hexed: StringArray = array
219 .iter()
220 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
221 .collect::<Result<_, _>>()?;
222
223 Ok(ColumnarValue::Array(Arc::new(hexed)))
224 }
225 DataType::FixedSizeBinary(_) => {
226 let array = as_fixed_size_binary_array(array)?;
227
228 let hexed: StringArray = array
229 .iter()
230 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
231 .collect::<Result<_, _>>()?;
232
233 Ok(ColumnarValue::Array(Arc::new(hexed)))
234 }
235 DataType::Dictionary(_, value_type) => {
236 let dict = as_dictionary_array::<Int32Type>(&array);
237
238 let values = match **value_type {
239 DataType::Int64 => as_int64_array(dict.values())?
240 .iter()
241 .map(|v| v.map(hex_int64))
242 .collect::<Vec<_>>(),
243 DataType::Utf8 => as_string_array(dict.values())
244 .iter()
245 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
246 .collect::<Result<_, _>>()?,
247 DataType::Binary => as_binary_array(dict.values())?
248 .iter()
249 .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
250 .collect::<Result<_, _>>()?,
251 _ => exec_err!(
252 "hex got an unexpected argument type: {}",
253 array.data_type()
254 )?,
255 };
256
257 let new_values: Vec<Option<String>> = dict
258 .keys()
259 .iter()
260 .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None))
261 .collect();
262
263 let string_array_values = StringArray::from(new_values);
264
265 Ok(ColumnarValue::Array(Arc::new(string_array_values)))
266 }
267 _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
268 },
269 _ => exec_err!("native hex does not support scalar values at this time"),
270 }
271}
272
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{Int64Array, StringArray};
    use arrow::{
        array::{
            BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder,
            StringDictionaryBuilder, as_string_array,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use datafusion_expr::ColumnarValue;

    /// Runs `spark_hex` on a single array argument and returns the output as a
    /// `StringArray`. Panics if evaluation fails or the result is not an array.
    fn hex_of<A: arrow::array::Array + 'static>(input: A) -> StringArray {
        let output = super::spark_hex(&[ColumnarValue::Array(Arc::new(input))]).unwrap();
        match output {
            ColumnarValue::Array(array) => as_string_array(&array).clone(),
            _ => panic!("Expected array"),
        }
    }

    /// Builds the expected `StringArray` from optional string items.
    fn expected_strings(items: &[Option<&str>]) -> StringArray {
        let mut builder = StringBuilder::new();
        for item in items {
            builder.append_option(*item);
        }
        builder.finish()
    }

    #[test]
    fn test_dictionary_hex_utf8() {
        let mut input_builder = StringDictionaryBuilder::<Int32Type>::new();
        for item in [Some("hi"), Some("bye"), None, Some("rust")] {
            if let Some(text) = item {
                input_builder.append_value(text);
            } else {
                input_builder.append_null();
            }
        }

        let expected =
            expected_strings(&[Some("6869"), Some("627965"), None, Some("72757374")]);

        assert_eq!(hex_of(input_builder.finish()), expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        let mut input_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        for item in [Some(1), Some(2), None, Some(3)] {
            if let Some(number) = item {
                input_builder.append_value(number);
            } else {
                input_builder.append_null();
            }
        }

        let expected = expected_strings(&[Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(hex_of(input_builder.finish()), expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        let mut input_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        for item in [Some("1"), Some("j"), None, Some("3")] {
            if let Some(bytes) = item {
                input_builder.append_value(bytes);
            } else {
                input_builder.append_null();
            }
        }

        let expected = expected_strings(&[Some("31"), Some("6A"), None, Some("33")]);

        assert_eq!(hex_of(input_builder.finish()), expected);
    }

    #[test]
    fn test_hex_int64() {
        // Positive values carry no padding; negatives use two's complement.
        assert_eq!(super::hex_int64(1234), "4D2".to_string());
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);
        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        assert_eq!(hex_of(input), expected);
    }
}