Skip to main content

datafusion_spark/function/math/
unhex.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{Array, ArrayRef, BinaryBuilder};
19use arrow::datatypes::DataType;
20use datafusion_common::cast::{
21    as_large_string_array, as_string_array, as_string_view_array,
22};
23use datafusion_common::types::logical_string;
24use datafusion_common::utils::take_function_args;
25use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
26use datafusion_expr::{
27    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    TypeSignatureClass, Volatility,
29};
30use std::any::Any;
31use std::sync::Arc;
32
33/// <https://spark.apache.org/docs/latest/api/sql/index.html#unhex>
34#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkUnhex {
36    signature: Signature,
37}
38
39impl Default for SparkUnhex {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl SparkUnhex {
46    pub fn new() -> Self {
47        let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
48
49        Self {
50            signature: Signature::coercible(vec![string], Volatility::Immutable),
51        }
52    }
53}
54
55impl ScalarUDFImpl for SparkUnhex {
56    fn as_any(&self) -> &dyn Any {
57        self
58    }
59
60    fn name(&self) -> &str {
61        "unhex"
62    }
63
64    fn signature(&self) -> &Signature {
65        &self.signature
66    }
67
68    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
69        Ok(DataType::Binary)
70    }
71
72    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
73        spark_unhex(&args.args)
74    }
75}
76
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`;
/// every other byte yields `None`.
#[inline]
fn hex_nibble(c: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits in either case and
    // rejects everything else, including bytes >= 0x80. The digit is at most
    // 15, so narrowing it to `u8` cannot truncate.
    char::from(c).to_digit(16).map(|digit| digit as u8)
}
86
87/// Decodes a hex-encoded byte slice into binary data.
88/// Returns `true` if decoding succeeded, `false` if the input contains invalid hex characters.
89fn unhex_common(bytes: &[u8], out: &mut Vec<u8>) -> bool {
90    if bytes.is_empty() {
91        return true;
92    }
93
94    let mut i = 0usize;
95
96    // If the hex string length is odd, implicitly left-pad with '0'.
97    if (bytes.len() & 1) == 1 {
98        match hex_nibble(bytes[0]) {
99            // Equivalent to (0 << 4) | lo
100            Some(lo) => out.push(lo),
101            None => return false,
102        }
103        i = 1;
104    }
105
106    while i + 1 < bytes.len() {
107        match (hex_nibble(bytes[i]), hex_nibble(bytes[i + 1])) {
108            (Some(hi), Some(lo)) => out.push((hi << 4) | lo),
109            _ => return false,
110        }
111        i += 2;
112    }
113
114    true
115}
116
117/// Converts an iterator of hex strings to a binary array.
118fn unhex_array<I, T>(
119    iter: I,
120    len: usize,
121    capacity: usize,
122) -> Result<ArrayRef, DataFusionError>
123where
124    I: Iterator<Item = Option<T>>,
125    T: AsRef<str>,
126{
127    let mut builder = BinaryBuilder::with_capacity(len, capacity);
128    let mut buffer = Vec::new();
129
130    for v in iter {
131        if let Some(s) = v {
132            buffer.clear();
133            buffer.reserve(s.as_ref().len().div_ceil(2));
134            if unhex_common(s.as_ref().as_bytes(), &mut buffer) {
135                builder.append_value(&buffer);
136            } else {
137                builder.append_null();
138            }
139        } else {
140            builder.append_null();
141        }
142    }
143
144    Ok(Arc::new(builder.finish()))
145}
146
147/// Convert a single hex string to binary
148fn unhex_scalar(s: &str) -> Option<Vec<u8>> {
149    let mut buffer = Vec::with_capacity(s.len().div_ceil(2));
150    if unhex_common(s.as_bytes(), &mut buffer) {
151        Some(buffer)
152    } else {
153        None
154    }
155}
156
157fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
158    let [args] = take_function_args("unhex", args)?;
159
160    match args {
161        ColumnarValue::Array(array) => match array.data_type() {
162            DataType::Utf8 => {
163                let array = as_string_array(array)?;
164                let capacity = array.values().len().div_ceil(2);
165                Ok(ColumnarValue::Array(unhex_array(
166                    array.iter(),
167                    array.len(),
168                    capacity,
169                )?))
170            }
171            DataType::Utf8View => {
172                let array = as_string_view_array(array)?;
173                // Estimate capacity since StringViewArray data can be scattered or inlined.
174                let capacity = array.len() * 32;
175                Ok(ColumnarValue::Array(unhex_array(
176                    array.iter(),
177                    array.len(),
178                    capacity,
179                )?))
180            }
181            DataType::LargeUtf8 => {
182                let array = as_large_string_array(array)?;
183                let capacity = array.values().len().div_ceil(2);
184                Ok(ColumnarValue::Array(unhex_array(
185                    array.iter(),
186                    array.len(),
187                    capacity,
188                )?))
189            }
190            _ => exec_err!(
191                "unhex only supports string argument, but got: {}",
192                array.data_type()
193            ),
194        },
195        ColumnarValue::Scalar(sv) => match sv {
196            ScalarValue::Utf8(None)
197            | ScalarValue::Utf8View(None)
198            | ScalarValue::LargeUtf8(None) => {
199                Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
200            }
201            ScalarValue::Utf8(Some(s))
202            | ScalarValue::Utf8View(Some(s))
203            | ScalarValue::LargeUtf8(Some(s)) => {
204                Ok(ColumnarValue::Scalar(ScalarValue::Binary(unhex_scalar(s))))
205            }
206            _ => {
207                exec_err!(
208                    "unhex only supports string argument, but got: {}",
209                    sv.data_type()
210                )
211            }
212        },
213    }
214}