datafusion_comet_spark_expr/math_funcs/
unhex.rs1use std::sync::Arc;
19
20use arrow::array::OffsetSizeTrait;
21use arrow::datatypes::DataType;
22use datafusion::common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};
23use datafusion::logical_expr::ColumnarValue;
24
25fn unhex_digit(c: u8) -> Result<u8, DataFusionError> {
27 match c {
28 b'0'..=b'9' => Ok(c - b'0'),
29 b'A'..=b'F' => Ok(10 + c - b'A'),
30 b'a'..=b'f' => Ok(10 + c - b'a'),
31 _ => Err(DataFusionError::Execution(
32 "Input to unhex_digit is not a valid hex digit".to_string(),
33 )),
34 }
35}
36
37fn unhex(hex_str: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
40 let bytes = hex_str.as_bytes();
41
42 let mut i = 0;
43
44 if (bytes.len() & 0x01) != 0 {
45 let v = unhex_digit(bytes[0])?;
46
47 result.push(v);
48 i += 1;
49 }
50
51 while i < bytes.len() {
52 let first = unhex_digit(bytes[i])?;
53 let second = unhex_digit(bytes[i + 1])?;
54 result.push((first << 4) | second);
55
56 i += 2;
57 }
58
59 Ok(())
60}
61
62fn spark_unhex_inner<T: OffsetSizeTrait>(
63 array: &ColumnarValue,
64 fail_on_error: bool,
65) -> Result<ColumnarValue, DataFusionError> {
66 match array {
67 ColumnarValue::Array(array) => {
68 let string_array = as_generic_string_array::<T>(array)?;
69
70 let mut encoded = Vec::new();
71 let mut builder = arrow::array::BinaryBuilder::new();
72
73 for item in string_array.iter() {
74 if let Some(s) = item {
75 if unhex(s, &mut encoded).is_ok() {
76 builder.append_value(encoded.as_slice());
77 } else if fail_on_error {
78 return exec_err!("Input to unhex is not a valid hex string: {s}");
79 } else {
80 builder.append_null();
81 }
82 encoded.clear();
83 } else {
84 builder.append_null();
85 }
86 }
87 Ok(ColumnarValue::Array(Arc::new(builder.finish())))
88 }
89 ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => {
90 let mut encoded = Vec::new();
91
92 if unhex(string, &mut encoded).is_ok() {
93 Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded))))
94 } else if fail_on_error {
95 exec_err!("Input to unhex is not a valid hex string: {string}")
96 } else {
97 Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
98 }
99 }
100 ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
101 Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
102 }
103 _ => {
104 exec_err!(
105 "The first argument must be a string scalar or array, but got: {:?}",
106 array
107 )
108 }
109 }
110}
111
112pub fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
114 if args.len() > 2 {
115 return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len());
116 }
117
118 let val_to_unhex = &args[0];
119 let fail_on_error = if args.len() == 2 {
120 match &args[1] {
121 ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
122 _ => {
123 return exec_err!(
124 "The second argument must be boolean scalar, but got: {:?}",
125 args[1]
126 );
127 }
128 }
129 } else {
130 false
131 };
132
133 match val_to_unhex.data_type() {
134 DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
135 DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
136 other => exec_err!(
137 "The first argument must be a Utf8 or LargeUtf8: {:?}",
138 other
139 ),
140 }
141}
142
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::make_array;
    use arrow::array::ArrayData;
    use arrow::array::{BinaryBuilder, StringBuilder};
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ColumnarValue;

    use super::unhex;

    /// Result type shared by the fallible tests in this module.
    type TestResult<T = ()> = Result<T, Box<dyn std::error::Error>>;

    /// Runs `unhex` on `hex` with a fresh output buffer and returns the bytes.
    fn decode(hex: &str) -> TestResult<Vec<u8>> {
        let mut bytes = Vec::new();
        unhex(hex, &mut bytes)?;
        Ok(bytes)
    }

    /// Asserts that both values are arrays holding equal contents.
    fn assert_same_array(actual: ColumnarValue, expected: ColumnarValue) -> TestResult {
        match (actual, expected) {
            (ColumnarValue::Array(actual), ColumnarValue::Array(expected)) => {
                assert_eq!(*actual, *expected);
                Ok(())
            }
            _ => Err("Unexpected result type".into()),
        }
    }

    /// An all-null string array decodes to an all-null binary array.
    #[test]
    fn test_spark_unhex_null() -> TestResult {
        let null_strings = ArrayData::new_null(&arrow::datatypes::DataType::Utf8, 2);
        let null_binaries = ArrayData::new_null(&arrow::datatypes::DataType::Binary, 2);

        let actual =
            super::spark_unhex(&[ColumnarValue::Array(Arc::new(make_array(null_strings)))])?;
        let expected = ColumnarValue::Array(Arc::new(make_array(null_binaries)));

        assert_same_array(actual, expected)
    }

    /// With `fail_on_error = false`, an invalid row becomes null while the
    /// valid rows are still decoded.
    #[test]
    fn test_partial_error() -> TestResult {
        let mut hex_rows = StringBuilder::new();
        // Invalid: contains the non-hex character 'G'.
        hex_rows.append_value("1CGG");
        // Valid: hex encoding of "Spark SQL".
        hex_rows.append_value("537061726B2053514C");

        let mut decoded_rows = BinaryBuilder::new();
        decoded_rows.append_null();
        decoded_rows.append_value("Spark SQL".as_bytes());

        let actual = super::spark_unhex(&[
            ColumnarValue::Array(Arc::new(hex_rows.finish())),
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
        ])?;
        let expected = ColumnarValue::Array(Arc::new(decoded_rows.finish()));

        assert_same_array(actual, expected)
    }

    /// Well-formed inputs decode to the expected bytes.
    #[test]
    fn test_unhex_valid() -> TestResult {
        let spark_sql = decode("537061726B2053514C")?;
        assert_eq!(std::str::from_utf8(&spark_sql)?, "Spark SQL");

        assert_eq!(decode("1C")?, vec![28]);
        assert_eq!(decode("737472696E67")?, "string".as_bytes());
        assert_eq!(decode("1")?, vec![1]);

        Ok(())
    }

    /// An odd-length input decodes as if it had a leading zero.
    #[test]
    fn test_odd_length() -> TestResult {
        assert_eq!(decode("A1B")?, vec![10, 27]);
        assert_eq!(decode("0A1B")?, vec![10, 27]);

        Ok(())
    }

    /// The empty string decodes to no bytes.
    #[test]
    fn test_unhex_empty() {
        let bytes = decode("").unwrap();
        assert!(bytes.is_empty());
    }

    /// Inputs containing any non-hex character are rejected.
    #[test]
    fn test_unhex_invalid() {
        for bad in ["##", "G123", "hello", "\0"].iter() {
            assert!(decode(bad).is_err());
        }
    }
}