Skip to main content

datafusion_spark/function/hash/
sha2.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
19use arrow::datatypes::{DataType, Int32Type};
20use datafusion_common::types::{
21    NativeType, logical_binary, logical_int32, logical_string,
22};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, internal_err};
25use datafusion_expr::{
26    Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27    TypeSignatureClass, Volatility,
28};
29use datafusion_functions::utils::make_scalar_function;
30use sha2::{self, Digest};
31use std::any::Any;
32use std::sync::Arc;
33
34/// Differs from DataFusion version in allowing array input for bit lengths, and
35/// also hex encoding the output.
36///
37/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2>
38#[derive(Debug, PartialEq, Eq, Hash)]
39pub struct SparkSha2 {
40    signature: Signature,
41}
42
43impl Default for SparkSha2 {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl SparkSha2 {
50    pub fn new() -> Self {
51        Self {
52            signature: Signature::coercible(
53                vec![
54                    Coercion::new_implicit(
55                        TypeSignatureClass::Native(logical_binary()),
56                        vec![TypeSignatureClass::Native(logical_string())],
57                        NativeType::Binary,
58                    ),
59                    Coercion::new_implicit(
60                        TypeSignatureClass::Native(logical_int32()),
61                        vec![TypeSignatureClass::Integer],
62                        NativeType::Int32,
63                    ),
64                ],
65                Volatility::Immutable,
66            ),
67        }
68    }
69}
70
71impl ScalarUDFImpl for SparkSha2 {
72    fn as_any(&self) -> &dyn Any {
73        self
74    }
75
76    fn name(&self) -> &str {
77        "sha2"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85        Ok(DataType::Utf8)
86    }
87
88    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89        let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?;
90
91        match (values, bit_lengths) {
92            (
93                ColumnarValue::Scalar(value_scalar),
94                ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
95            ) => {
96                if value_scalar.is_null() {
97                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
98                }
99
100                // Accept both Binary and Utf8 scalars (depending on coercion)
101                let bytes = match value_scalar {
102                    ScalarValue::Binary(Some(b)) => b.as_slice(),
103                    ScalarValue::LargeBinary(Some(b)) => b.as_slice(),
104                    ScalarValue::BinaryView(Some(b)) => b.as_slice(),
105                    ScalarValue::Utf8(Some(s))
106                    | ScalarValue::LargeUtf8(Some(s))
107                    | ScalarValue::Utf8View(Some(s)) => s.as_bytes(),
108                    other => {
109                        return internal_err!(
110                            "Unsupported scalar datatype for sha2: {}",
111                            other.data_type()
112                        );
113                    }
114                };
115
116                let out = match bit_length {
117                    224 => {
118                        let mut digest = sha2::Sha224::default();
119                        digest.update(bytes);
120                        Some(hex_encode(digest.finalize()))
121                    }
122                    0 | 256 => {
123                        let mut digest = sha2::Sha256::default();
124                        digest.update(bytes);
125                        Some(hex_encode(digest.finalize()))
126                    }
127                    384 => {
128                        let mut digest = sha2::Sha384::default();
129                        digest.update(bytes);
130                        Some(hex_encode(digest.finalize()))
131                    }
132                    512 => {
133                        let mut digest = sha2::Sha512::default();
134                        digest.update(bytes);
135                        Some(hex_encode(digest.finalize()))
136                    }
137                    _ => None,
138                };
139
140                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out)))
141            }
142            // Array values + scalar bit length (common case: sha2(col, 256))
143            (
144                ColumnarValue::Array(values_array),
145                ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
146            ) => {
147                let output: ArrayRef = match values_array.data_type() {
148                    DataType::Binary => sha2_binary_scalar_bitlen(
149                        &values_array.as_binary::<i32>(),
150                        *bit_length,
151                    ),
152                    DataType::LargeBinary => sha2_binary_scalar_bitlen(
153                        &values_array.as_binary::<i64>(),
154                        *bit_length,
155                    ),
156                    DataType::BinaryView => sha2_binary_scalar_bitlen(
157                        &values_array.as_binary_view(),
158                        *bit_length,
159                    ),
160                    dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
161                };
162                Ok(ColumnarValue::Array(output))
163            }
164            (
165                ColumnarValue::Scalar(_),
166                ColumnarValue::Scalar(ScalarValue::Int32(None)),
167            ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
168            (
169                ColumnarValue::Array(_),
170                ColumnarValue::Scalar(ScalarValue::Int32(None)),
171            ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
172            _ => {
173                // Fallback to existing behavior for any array/mixed cases
174                make_scalar_function(sha2_impl, vec![])(&args.args)
175            }
176        }
177    }
178}
179
180fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
181    let [values, bit_lengths] = take_function_args("sha2", args)?;
182
183    let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
184    let output = match values.data_type() {
185        DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths),
186        DataType::LargeBinary => {
187            sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
188        }
189        DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths),
190        dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
191    };
192    Ok(output)
193}
194
195fn sha2_binary_impl<'a, BinaryArrType>(
196    values: &BinaryArrType,
197    bit_lengths: &Int32Array,
198) -> ArrayRef
199where
200    BinaryArrType: BinaryArrayType<'a>,
201{
202    sha2_binary_bitlen_iter(values, bit_lengths.iter())
203}
204
205fn sha2_binary_scalar_bitlen<'a, BinaryArrType>(
206    values: &BinaryArrType,
207    bit_length: i32,
208) -> ArrayRef
209where
210    BinaryArrType: BinaryArrayType<'a>,
211{
212    sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length)))
213}
214
215fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>(
216    values: &BinaryArrType,
217    bit_lengths: I,
218) -> ArrayRef
219where
220    BinaryArrType: BinaryArrayType<'a>,
221    I: Iterator<Item = Option<i32>>,
222{
223    let array = values
224        .iter()
225        .zip(bit_lengths)
226        .map(|(value, bit_length)| match (value, bit_length) {
227            (Some(value), Some(224)) => {
228                let mut digest = sha2::Sha224::default();
229                digest.update(value);
230                Some(hex_encode(digest.finalize()))
231            }
232            (Some(value), Some(0 | 256)) => {
233                let mut digest = sha2::Sha256::default();
234                digest.update(value);
235                Some(hex_encode(digest.finalize()))
236            }
237            (Some(value), Some(384)) => {
238                let mut digest = sha2::Sha384::default();
239                digest.update(value);
240                Some(hex_encode(digest.finalize()))
241            }
242            (Some(value), Some(512)) => {
243                let mut digest = sha2::Sha512::default();
244                digest.update(value);
245                Some(hex_encode(digest.finalize()))
246            }
247            // Unknown bit-lengths go to null, same as in Spark
248            _ => None,
249        })
250        .collect::<StringArray>();
251    Arc::new(array)
252}
253
/// Lowercase hexadecimal digits, indexed by nibble value (0..=15).
const HEX_CHARS: [u8; 16] = *b"0123456789abcdef";

/// Encodes `data` as a lowercase hexadecimal string, two characters per input
/// byte with the high nibble first.
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let input = data.as_ref();
    let mut encoded = Vec::with_capacity(input.len() * 2);
    for byte in input.iter().copied() {
        encoded.extend_from_slice(&[
            HEX_CHARS[usize::from(byte >> 4)],
            HEX_CHARS[usize::from(byte & 0x0F)],
        ]);
    }
    // SAFETY: every byte in `encoded` was copied from `HEX_CHARS`, which holds
    // only ASCII characters, so the buffer is valid UTF-8.
    unsafe { String::from_utf8_unchecked(encoded) }
}