datafusion_comet_spark_expr/math_funcs/
unhex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::OffsetSizeTrait;
21use arrow::datatypes::DataType;
22use datafusion::common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};
23use datafusion::logical_expr::ColumnarValue;
24
25/// Helper function to convert a hex digit to a binary value.
26fn unhex_digit(c: u8) -> Result<u8, DataFusionError> {
27    match c {
28        b'0'..=b'9' => Ok(c - b'0'),
29        b'A'..=b'F' => Ok(10 + c - b'A'),
30        b'a'..=b'f' => Ok(10 + c - b'a'),
31        _ => Err(DataFusionError::Execution(
32            "Input to unhex_digit is not a valid hex digit".to_string(),
33        )),
34    }
35}
36
37/// Convert a hex string to binary and store the result in `result`. Returns an error if the input
38/// is not a valid hex string.
39fn unhex(hex_str: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
40    let bytes = hex_str.as_bytes();
41
42    let mut i = 0;
43
44    if (bytes.len() & 0x01) != 0 {
45        let v = unhex_digit(bytes[0])?;
46
47        result.push(v);
48        i += 1;
49    }
50
51    while i < bytes.len() {
52        let first = unhex_digit(bytes[i])?;
53        let second = unhex_digit(bytes[i + 1])?;
54        result.push((first << 4) | second);
55
56        i += 2;
57    }
58
59    Ok(())
60}
61
62fn spark_unhex_inner<T: OffsetSizeTrait>(
63    array: &ColumnarValue,
64    fail_on_error: bool,
65) -> Result<ColumnarValue, DataFusionError> {
66    match array {
67        ColumnarValue::Array(array) => {
68            let string_array = as_generic_string_array::<T>(array)?;
69
70            let mut encoded = Vec::new();
71            let mut builder = arrow::array::BinaryBuilder::new();
72
73            for item in string_array.iter() {
74                if let Some(s) = item {
75                    if unhex(s, &mut encoded).is_ok() {
76                        builder.append_value(encoded.as_slice());
77                    } else if fail_on_error {
78                        return exec_err!("Input to unhex is not a valid hex string: {s}");
79                    } else {
80                        builder.append_null();
81                    }
82                    encoded.clear();
83                } else {
84                    builder.append_null();
85                }
86            }
87            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
88        }
89        ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => {
90            let mut encoded = Vec::new();
91
92            if unhex(string, &mut encoded).is_ok() {
93                Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded))))
94            } else if fail_on_error {
95                exec_err!("Input to unhex is not a valid hex string: {string}")
96            } else {
97                Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
98            }
99        }
100        ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
101            Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
102        }
103        _ => {
104            exec_err!(
105                "The first argument must be a string scalar or array, but got: {:?}",
106                array
107            )
108        }
109    }
110}
111
112/// Spark-compatible `unhex` expression
113pub fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
114    if args.len() > 2 {
115        return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len());
116    }
117
118    let val_to_unhex = &args[0];
119    let fail_on_error = if args.len() == 2 {
120        match &args[1] {
121            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
122            _ => {
123                return exec_err!(
124                    "The second argument must be boolean scalar, but got: {:?}",
125                    args[1]
126                );
127            }
128        }
129    } else {
130        false
131    };
132
133    match val_to_unhex.data_type() {
134        DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
135        DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
136        other => exec_err!(
137            "The first argument must be a Utf8 or LargeUtf8: {:?}",
138            other
139        ),
140    }
141}
142
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::make_array;
    use arrow::array::ArrayData;
    use arrow::array::{BinaryBuilder, StringBuilder};
    use arrow::datatypes::DataType;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ColumnarValue;

    use super::unhex;

    /// Return type shared by the fallible tests in this module.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Asserts that both values are arrays with equal contents; any other combination is
    /// reported as an error.
    fn assert_same_arrays(actual: ColumnarValue, expected: ColumnarValue) -> TestResult {
        match (actual, expected) {
            (ColumnarValue::Array(actual), ColumnarValue::Array(expected)) => {
                assert_eq!(*actual, *expected);
                Ok(())
            }
            _ => Err("Unexpected result type".into()),
        }
    }

    /// An all-null string array decodes to an all-null binary array of the same length.
    #[test]
    fn test_spark_unhex_null() -> TestResult {
        let null_strings = make_array(ArrayData::new_null(&DataType::Utf8, 2));
        let null_binaries = make_array(ArrayData::new_null(&DataType::Binary, 2));

        let actual = super::spark_unhex(&[ColumnarValue::Array(Arc::new(null_strings))])?;

        assert_same_arrays(actual, ColumnarValue::Array(Arc::new(null_binaries)))
    }

    /// With `fail_on_error = false`, an invalid row becomes null and does not affect the
    /// rows that follow it.
    #[test]
    fn test_partial_error() -> TestResult {
        let mut hex_values = StringBuilder::new();
        hex_values.append_value("1CGG"); // 1C is ok, but GG is invalid
        hex_values.append_value("537061726B2053514C"); // followed by valid

        let args = [
            ColumnarValue::Array(Arc::new(hex_values.finish())),
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
        ];
        let actual = super::spark_unhex(&args)?;

        let mut expected = BinaryBuilder::new();
        expected.append_null();
        expected.append_value("Spark SQL".as_bytes());

        assert_same_arrays(actual, ColumnarValue::Array(Arc::new(expected.finish())))
    }

    /// Well-formed hex strings decode to the expected bytes.
    #[test]
    fn test_unhex_valid() -> TestResult {
        let mut decoded = Vec::new();

        unhex("537061726B2053514C", &mut decoded)?;
        assert_eq!(std::str::from_utf8(&decoded)?, "Spark SQL");

        let byte_cases: [(&str, &[u8]); 3] = [
            ("1C", &[28]),
            ("737472696E67", "string".as_bytes()),
            ("1", &[1]),
        ];
        for &(hex, expected) in &byte_cases {
            decoded.clear();
            unhex(hex, &mut decoded)?;
            assert_eq!(decoded, expected);
        }

        Ok(())
    }

    /// An odd-length input decodes the same as that input with a leading zero.
    #[test]
    fn test_odd_length() -> TestResult {
        let mut decoded = Vec::new();

        for hex in ["A1B", "0A1B"] {
            decoded.clear();
            unhex(hex, &mut decoded)?;
            assert_eq!(decoded, vec![10, 27]);
        }

        Ok(())
    }

    /// The empty string is valid input and decodes to no bytes.
    #[test]
    fn test_unhex_empty() {
        let mut decoded = Vec::new();

        unhex("", &mut decoded).unwrap();

        assert!(decoded.is_empty());
    }

    /// Strings containing non-hex characters are rejected.
    #[test]
    fn test_unhex_invalid() {
        let mut decoded = Vec::new();

        for bad_input in ["##", "G123", "hello", "\0"] {
            assert!(unhex(bad_input, &mut decoded).is_err());
        }
    }
}