datafusion_spark/function/hash/
sha2.rs1use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
19use arrow::datatypes::{DataType, Int32Type};
20use datafusion_common::types::{
21 NativeType, logical_binary, logical_int32, logical_string,
22};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, internal_err};
25use datafusion_expr::{
26 Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27 TypeSignatureClass, Volatility,
28};
29use datafusion_functions::utils::make_scalar_function;
30use sha2::{self, Digest};
31use std::any::Any;
32use std::sync::Arc;
33
34#[derive(Debug, PartialEq, Eq, Hash)]
39pub struct SparkSha2 {
40 signature: Signature,
41}
42
43impl Default for SparkSha2 {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl SparkSha2 {
50 pub fn new() -> Self {
51 Self {
52 signature: Signature::coercible(
53 vec![
54 Coercion::new_implicit(
55 TypeSignatureClass::Native(logical_binary()),
56 vec![TypeSignatureClass::Native(logical_string())],
57 NativeType::Binary,
58 ),
59 Coercion::new_implicit(
60 TypeSignatureClass::Native(logical_int32()),
61 vec![TypeSignatureClass::Integer],
62 NativeType::Int32,
63 ),
64 ],
65 Volatility::Immutable,
66 ),
67 }
68 }
69}
70
71impl ScalarUDFImpl for SparkSha2 {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "sha2"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
85 Ok(DataType::Utf8)
86 }
87
88 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
89 let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?;
90
91 match (values, bit_lengths) {
92 (
93 ColumnarValue::Scalar(value_scalar),
94 ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
95 ) => {
96 if value_scalar.is_null() {
97 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
98 }
99
100 let bytes = match value_scalar {
102 ScalarValue::Binary(Some(b)) => b.as_slice(),
103 ScalarValue::LargeBinary(Some(b)) => b.as_slice(),
104 ScalarValue::BinaryView(Some(b)) => b.as_slice(),
105 ScalarValue::Utf8(Some(s))
106 | ScalarValue::LargeUtf8(Some(s))
107 | ScalarValue::Utf8View(Some(s)) => s.as_bytes(),
108 other => {
109 return internal_err!(
110 "Unsupported scalar datatype for sha2: {}",
111 other.data_type()
112 );
113 }
114 };
115
116 let out = match bit_length {
117 224 => {
118 let mut digest = sha2::Sha224::default();
119 digest.update(bytes);
120 Some(hex_encode(digest.finalize()))
121 }
122 0 | 256 => {
123 let mut digest = sha2::Sha256::default();
124 digest.update(bytes);
125 Some(hex_encode(digest.finalize()))
126 }
127 384 => {
128 let mut digest = sha2::Sha384::default();
129 digest.update(bytes);
130 Some(hex_encode(digest.finalize()))
131 }
132 512 => {
133 let mut digest = sha2::Sha512::default();
134 digest.update(bytes);
135 Some(hex_encode(digest.finalize()))
136 }
137 _ => None,
138 };
139
140 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out)))
141 }
142 (
144 ColumnarValue::Array(values_array),
145 ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
146 ) => {
147 let output: ArrayRef = match values_array.data_type() {
148 DataType::Binary => sha2_binary_scalar_bitlen(
149 &values_array.as_binary::<i32>(),
150 *bit_length,
151 ),
152 DataType::LargeBinary => sha2_binary_scalar_bitlen(
153 &values_array.as_binary::<i64>(),
154 *bit_length,
155 ),
156 DataType::BinaryView => sha2_binary_scalar_bitlen(
157 &values_array.as_binary_view(),
158 *bit_length,
159 ),
160 dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
161 };
162 Ok(ColumnarValue::Array(output))
163 }
164 (
165 ColumnarValue::Scalar(_),
166 ColumnarValue::Scalar(ScalarValue::Int32(None)),
167 ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
168 (
169 ColumnarValue::Array(_),
170 ColumnarValue::Scalar(ScalarValue::Int32(None)),
171 ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
172 _ => {
173 make_scalar_function(sha2_impl, vec![])(&args.args)
175 }
176 }
177 }
178}
179
180fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
181 let [values, bit_lengths] = take_function_args("sha2", args)?;
182
183 let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
184 let output = match values.data_type() {
185 DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths),
186 DataType::LargeBinary => {
187 sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
188 }
189 DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths),
190 dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
191 };
192 Ok(output)
193}
194
195fn sha2_binary_impl<'a, BinaryArrType>(
196 values: &BinaryArrType,
197 bit_lengths: &Int32Array,
198) -> ArrayRef
199where
200 BinaryArrType: BinaryArrayType<'a>,
201{
202 sha2_binary_bitlen_iter(values, bit_lengths.iter())
203}
204
205fn sha2_binary_scalar_bitlen<'a, BinaryArrType>(
206 values: &BinaryArrType,
207 bit_length: i32,
208) -> ArrayRef
209where
210 BinaryArrType: BinaryArrayType<'a>,
211{
212 sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length)))
213}
214
215fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>(
216 values: &BinaryArrType,
217 bit_lengths: I,
218) -> ArrayRef
219where
220 BinaryArrType: BinaryArrayType<'a>,
221 I: Iterator<Item = Option<i32>>,
222{
223 let array = values
224 .iter()
225 .zip(bit_lengths)
226 .map(|(value, bit_length)| match (value, bit_length) {
227 (Some(value), Some(224)) => {
228 let mut digest = sha2::Sha224::default();
229 digest.update(value);
230 Some(hex_encode(digest.finalize()))
231 }
232 (Some(value), Some(0 | 256)) => {
233 let mut digest = sha2::Sha256::default();
234 digest.update(value);
235 Some(hex_encode(digest.finalize()))
236 }
237 (Some(value), Some(384)) => {
238 let mut digest = sha2::Sha384::default();
239 digest.update(value);
240 Some(hex_encode(digest.finalize()))
241 }
242 (Some(value), Some(512)) => {
243 let mut digest = sha2::Sha512::default();
244 digest.update(value);
245 Some(hex_encode(digest.finalize()))
246 }
247 _ => None,
249 })
250 .collect::<StringArray>();
251 Arc::new(array)
252}
253
/// Lowercase hexadecimal alphabet, indexed by nibble value (`0..=15`).
const HEX_CHARS: [u8; 16] = *b"0123456789abcdef";

/// Encodes `data` as a lowercase hexadecimal string, two characters per
/// input byte (high nibble first).
///
/// The output is built directly as a `String`, so no `unsafe` UTF-8
/// conversion is needed: every pushed character comes from [`HEX_CHARS`],
/// which contains only ASCII.
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let bytes = data.as_ref();
    // Exactly two output characters per input byte, so reserve once up front.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        // A nibble is always in 0..=15, so indexing the 16-entry table
        // cannot go out of bounds.
        out.push(char::from(HEX_CHARS[usize::from(b >> 4)]));
        out.push(char::from(HEX_CHARS[usize::from(b & 0x0F)]));
    }
    out
}