datafusion_spark/function/hash/
sha2.rs1use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
19use arrow::datatypes::{DataType, Int32Type};
20use datafusion_common::types::{
21 NativeType, logical_binary, logical_int32, logical_string,
22};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{Result, ScalarValue, internal_err};
25use datafusion_expr::{
26 Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27 TypeSignatureClass, Volatility,
28};
29use datafusion_functions::utils::make_scalar_function;
30use sha2::{self, Digest};
31use std::sync::Arc;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
38pub struct SparkSha2 {
39 signature: Signature,
40}
41
42impl Default for SparkSha2 {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SparkSha2 {
49 pub fn new() -> Self {
50 Self {
51 signature: Signature::coercible(
52 vec![
53 Coercion::new_implicit(
54 TypeSignatureClass::Native(logical_binary()),
55 vec![TypeSignatureClass::Native(logical_string())],
56 NativeType::Binary,
57 ),
58 Coercion::new_implicit(
59 TypeSignatureClass::Native(logical_int32()),
60 vec![TypeSignatureClass::Integer],
61 NativeType::Int32,
62 ),
63 ],
64 Volatility::Immutable,
65 ),
66 }
67 }
68}
69
70impl ScalarUDFImpl for SparkSha2 {
71 fn name(&self) -> &str {
72 "sha2"
73 }
74
75 fn signature(&self) -> &Signature {
76 &self.signature
77 }
78
79 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80 Ok(DataType::Utf8)
81 }
82
83 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
84 let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?;
85
86 match (values, bit_lengths) {
87 (
88 ColumnarValue::Scalar(value_scalar),
89 ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
90 ) => {
91 if value_scalar.is_null() {
92 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
93 }
94
95 let bytes = match value_scalar {
97 ScalarValue::Binary(Some(b)) => b.as_slice(),
98 ScalarValue::LargeBinary(Some(b)) => b.as_slice(),
99 ScalarValue::BinaryView(Some(b)) => b.as_slice(),
100 ScalarValue::Utf8(Some(s))
101 | ScalarValue::LargeUtf8(Some(s))
102 | ScalarValue::Utf8View(Some(s)) => s.as_bytes(),
103 other => {
104 return internal_err!(
105 "Unsupported scalar datatype for sha2: {}",
106 other.data_type()
107 );
108 }
109 };
110
111 let out = match bit_length {
112 224 => {
113 let mut digest = sha2::Sha224::default();
114 digest.update(bytes);
115 Some(hex_encode(digest.finalize()))
116 }
117 0 | 256 => {
118 let mut digest = sha2::Sha256::default();
119 digest.update(bytes);
120 Some(hex_encode(digest.finalize()))
121 }
122 384 => {
123 let mut digest = sha2::Sha384::default();
124 digest.update(bytes);
125 Some(hex_encode(digest.finalize()))
126 }
127 512 => {
128 let mut digest = sha2::Sha512::default();
129 digest.update(bytes);
130 Some(hex_encode(digest.finalize()))
131 }
132 _ => None,
133 };
134
135 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out)))
136 }
137 (
139 ColumnarValue::Array(values_array),
140 ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
141 ) => {
142 let output: ArrayRef = match values_array.data_type() {
143 DataType::Binary => sha2_binary_scalar_bitlen(
144 &values_array.as_binary::<i32>(),
145 *bit_length,
146 ),
147 DataType::LargeBinary => sha2_binary_scalar_bitlen(
148 &values_array.as_binary::<i64>(),
149 *bit_length,
150 ),
151 DataType::BinaryView => sha2_binary_scalar_bitlen(
152 &values_array.as_binary_view(),
153 *bit_length,
154 ),
155 dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
156 };
157 Ok(ColumnarValue::Array(output))
158 }
159 (
160 ColumnarValue::Scalar(_),
161 ColumnarValue::Scalar(ScalarValue::Int32(None)),
162 ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
163 (
164 ColumnarValue::Array(_),
165 ColumnarValue::Scalar(ScalarValue::Int32(None)),
166 ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
167 _ => {
168 make_scalar_function(sha2_impl, vec![])(&args.args)
170 }
171 }
172 }
173}
174
175fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
176 let [values, bit_lengths] = take_function_args("sha2", args)?;
177
178 let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
179 let output = match values.data_type() {
180 DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths),
181 DataType::LargeBinary => {
182 sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
183 }
184 DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths),
185 dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
186 };
187 Ok(output)
188}
189
190fn sha2_binary_impl<'a, BinaryArrType>(
191 values: &BinaryArrType,
192 bit_lengths: &Int32Array,
193) -> ArrayRef
194where
195 BinaryArrType: BinaryArrayType<'a>,
196{
197 sha2_binary_bitlen_iter(values, bit_lengths.iter())
198}
199
200fn sha2_binary_scalar_bitlen<'a, BinaryArrType>(
201 values: &BinaryArrType,
202 bit_length: i32,
203) -> ArrayRef
204where
205 BinaryArrType: BinaryArrayType<'a>,
206{
207 sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length)))
208}
209
210fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>(
211 values: &BinaryArrType,
212 bit_lengths: I,
213) -> ArrayRef
214where
215 BinaryArrType: BinaryArrayType<'a>,
216 I: Iterator<Item = Option<i32>>,
217{
218 let array = values
219 .iter()
220 .zip(bit_lengths)
221 .map(|(value, bit_length)| match (value, bit_length) {
222 (Some(value), Some(224)) => {
223 let mut digest = sha2::Sha224::default();
224 digest.update(value);
225 Some(hex_encode(digest.finalize()))
226 }
227 (Some(value), Some(0 | 256)) => {
228 let mut digest = sha2::Sha256::default();
229 digest.update(value);
230 Some(hex_encode(digest.finalize()))
231 }
232 (Some(value), Some(384)) => {
233 let mut digest = sha2::Sha384::default();
234 digest.update(value);
235 Some(hex_encode(digest.finalize()))
236 }
237 (Some(value), Some(512)) => {
238 let mut digest = sha2::Sha512::default();
239 digest.update(value);
240 Some(hex_encode(digest.finalize()))
241 }
242 _ => None,
244 })
245 .collect::<StringArray>();
246 Arc::new(array)
247}
248
/// Lowercase hexadecimal digits, indexed by nibble value.
const HEX_CHARS: [u8; 16] = *b"0123456789abcdef";

/// Encodes `data` as a lowercase hexadecimal string, two characters per byte
/// (high nibble first).
#[inline]
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let input = data.as_ref();
    let mut encoded = Vec::with_capacity(input.len() * 2);
    for byte in input.iter().copied() {
        let high = usize::from(byte >> 4);
        let low = usize::from(byte & 0x0F);
        encoded.extend_from_slice(&[HEX_CHARS[high], HEX_CHARS[low]]);
    }
    // SAFETY: every byte in `encoded` was copied from `HEX_CHARS`, which holds
    // only ASCII characters, so the buffer is valid UTF-8.
    unsafe { String::from_utf8_unchecked(encoded) }
}