Skip to main content

datafusion_spark/function/hash/
sha2.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
19use arrow::datatypes::{DataType, Int32Type};
20use datafusion_common::types::{
21    NativeType, logical_binary, logical_int32, logical_string,
22};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, internal_err};
25use datafusion_expr::{
26    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27    TypeSignatureClass, Volatility,
28};
29use datafusion_functions::utils::make_scalar_function;
30use sha2::{self, Digest};
31use std::sync::Arc;
32
33/// Differs from DataFusion version in allowing array input for bit lengths, and
34/// also hex encoding the output.
35///
36/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2>
37#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct SparkSha2 {
39    signature: Signature,
40}
41
42impl Default for SparkSha2 {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SparkSha2 {
49    pub fn new() -> Self {
50        Self {
51            signature: Signature::coercible(
52                vec![
53                    Coercion::new_implicit(
54                        TypeSignatureClass::Native(logical_binary()),
55                        vec![TypeSignatureClass::Native(logical_string())],
56                        NativeType::Binary,
57                    ),
58                    Coercion::new_implicit(
59                        TypeSignatureClass::Native(logical_int32()),
60                        vec![TypeSignatureClass::Integer],
61                        NativeType::Int32,
62                    ),
63                ],
64                Volatility::Immutable,
65            ),
66        }
67    }
68}
69
70impl ScalarUDFImpl for SparkSha2 {
71    fn name(&self) -> &str {
72        "sha2"
73    }
74
75    fn signature(&self) -> &Signature {
76        &self.signature
77    }
78
79    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80        Ok(DataType::Utf8)
81    }
82
83    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84        let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?;
85
86        match (values, bit_lengths) {
87            (
88                ColumnarValue::Scalar(value_scalar),
89                ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
90            ) => {
91                if value_scalar.is_null() {
92                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
93                }
94
95                // Accept both Binary and Utf8 scalars (depending on coercion)
96                let bytes = match value_scalar {
97                    ScalarValue::Binary(Some(b)) => b.as_slice(),
98                    ScalarValue::LargeBinary(Some(b)) => b.as_slice(),
99                    ScalarValue::BinaryView(Some(b)) => b.as_slice(),
100                    ScalarValue::Utf8(Some(s))
101                    | ScalarValue::LargeUtf8(Some(s))
102                    | ScalarValue::Utf8View(Some(s)) => s.as_bytes(),
103                    other => {
104                        return internal_err!(
105                            "Unsupported scalar datatype for sha2: {}",
106                            other.data_type()
107                        );
108                    }
109                };
110
111                let out = match bit_length {
112                    224 => {
113                        let mut digest = sha2::Sha224::default();
114                        digest.update(bytes);
115                        Some(hex_encode(digest.finalize()))
116                    }
117                    0 | 256 => {
118                        let mut digest = sha2::Sha256::default();
119                        digest.update(bytes);
120                        Some(hex_encode(digest.finalize()))
121                    }
122                    384 => {
123                        let mut digest = sha2::Sha384::default();
124                        digest.update(bytes);
125                        Some(hex_encode(digest.finalize()))
126                    }
127                    512 => {
128                        let mut digest = sha2::Sha512::default();
129                        digest.update(bytes);
130                        Some(hex_encode(digest.finalize()))
131                    }
132                    _ => None,
133                };
134
135                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out)))
136            }
137            // Array values + scalar bit length (common case: sha2(col, 256))
138            (
139                ColumnarValue::Array(values_array),
140                ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
141            ) => {
142                let output: ArrayRef = match values_array.data_type() {
143                    DataType::Binary => sha2_binary_scalar_bitlen(
144                        &values_array.as_binary::<i32>(),
145                        *bit_length,
146                    ),
147                    DataType::LargeBinary => sha2_binary_scalar_bitlen(
148                        &values_array.as_binary::<i64>(),
149                        *bit_length,
150                    ),
151                    DataType::BinaryView => sha2_binary_scalar_bitlen(
152                        &values_array.as_binary_view(),
153                        *bit_length,
154                    ),
155                    dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
156                };
157                Ok(ColumnarValue::Array(output))
158            }
159            (
160                ColumnarValue::Scalar(_),
161                ColumnarValue::Scalar(ScalarValue::Int32(None)),
162            ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
163            (
164                ColumnarValue::Array(_),
165                ColumnarValue::Scalar(ScalarValue::Int32(None)),
166            ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
167            _ => {
168                // Fallback to existing behavior for any array/mixed cases
169                make_scalar_function(sha2_impl, vec![])(&args.args)
170            }
171        }
172    }
173}
174
175fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
176    let [values, bit_lengths] = take_function_args("sha2", args)?;
177
178    let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
179    let output = match values.data_type() {
180        DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths),
181        DataType::LargeBinary => {
182            sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
183        }
184        DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths),
185        dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
186    };
187    Ok(output)
188}
189
190fn sha2_binary_impl<'a, BinaryArrType>(
191    values: &BinaryArrType,
192    bit_lengths: &Int32Array,
193) -> ArrayRef
194where
195    BinaryArrType: BinaryArrayType<'a>,
196{
197    sha2_binary_bitlen_iter(values, bit_lengths.iter())
198}
199
200fn sha2_binary_scalar_bitlen<'a, BinaryArrType>(
201    values: &BinaryArrType,
202    bit_length: i32,
203) -> ArrayRef
204where
205    BinaryArrType: BinaryArrayType<'a>,
206{
207    sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length)))
208}
209
210fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>(
211    values: &BinaryArrType,
212    bit_lengths: I,
213) -> ArrayRef
214where
215    BinaryArrType: BinaryArrayType<'a>,
216    I: Iterator<Item = Option<i32>>,
217{
218    let array = values
219        .iter()
220        .zip(bit_lengths)
221        .map(|(value, bit_length)| match (value, bit_length) {
222            (Some(value), Some(224)) => {
223                let mut digest = sha2::Sha224::default();
224                digest.update(value);
225                Some(hex_encode(digest.finalize()))
226            }
227            (Some(value), Some(0 | 256)) => {
228                let mut digest = sha2::Sha256::default();
229                digest.update(value);
230                Some(hex_encode(digest.finalize()))
231            }
232            (Some(value), Some(384)) => {
233                let mut digest = sha2::Sha384::default();
234                digest.update(value);
235                Some(hex_encode(digest.finalize()))
236            }
237            (Some(value), Some(512)) => {
238                let mut digest = sha2::Sha512::default();
239                digest.update(value);
240                Some(hex_encode(digest.finalize()))
241            }
242            // Unknown bit-lengths go to null, same as in Spark
243            _ => None,
244        })
245        .collect::<StringArray>();
246    Arc::new(array)
247}
248
/// Lowercase hexadecimal digits, indexed by nibble value.
const HEX_CHARS: [u8; 16] = *b"0123456789abcdef";

/// Encodes `data` as a lowercase hexadecimal string, two characters per input
/// byte (high nibble first).
///
/// The output is built directly as a `String`, so no `unsafe` UTF-8
/// conversion is needed.
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let bytes = data.as_ref();
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(HEX_CHARS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_CHARS[usize::from(byte & 0x0F)]));
    }
    encoded
}