datafusion_spark/function/hash/
sha2.rs1extern crate datafusion_functions;
19
20use crate::function::error_utils::{
21 invalid_arg_count_exec_err, unsupported_data_type_exec_err,
22};
23use crate::function::math::hex::spark_sha2_hex;
24use arrow::array::{ArrayRef, AsArray, StringArray};
25use arrow::datatypes::{DataType, Int32Type};
26use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
27use datafusion_expr::Signature;
28use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
29pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512};
30use std::any::Any;
31use std::sync::Arc;
32
33#[derive(Debug, PartialEq, Eq, Hash)]
35pub struct SparkSha2 {
36 signature: Signature,
37 aliases: Vec<String>,
38}
39
40impl Default for SparkSha2 {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkSha2 {
47 pub fn new() -> Self {
48 Self {
49 signature: Signature::user_defined(Volatility::Immutable),
50 aliases: vec![],
51 }
52 }
53}
54
55impl ScalarUDFImpl for SparkSha2 {
56 fn as_any(&self) -> &dyn Any {
57 self
58 }
59
60 fn name(&self) -> &str {
61 "sha2"
62 }
63
64 fn signature(&self) -> &Signature {
65 &self.signature
66 }
67
68 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
69 if arg_types[1].is_null() {
70 return Ok(DataType::Null);
71 }
72 Ok(match arg_types[0] {
73 DataType::Utf8View
74 | DataType::LargeUtf8
75 | DataType::Utf8
76 | DataType::Binary
77 | DataType::BinaryView
78 | DataType::LargeBinary => DataType::Utf8,
79 DataType::Null => DataType::Null,
80 _ => {
81 return exec_err!(
82 "{} function can only accept strings or binary arrays.",
83 self.name()
84 );
85 }
86 })
87 }
88
89 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90 let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
91 internal_datafusion_err!("Expected 2 arguments for function sha2")
92 })?;
93
94 sha2(args)
95 }
96
97 fn aliases(&self) -> &[String] {
98 &self.aliases
99 }
100
101 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
102 if arg_types.len() != 2 {
103 return Err(invalid_arg_count_exec_err(
104 self.name(),
105 (2, 2),
106 arg_types.len(),
107 ));
108 }
109 let expr_type = match &arg_types[0] {
110 DataType::Utf8View
111 | DataType::LargeUtf8
112 | DataType::Utf8
113 | DataType::Binary
114 | DataType::BinaryView
115 | DataType::LargeBinary
116 | DataType::Null => Ok(arg_types[0].clone()),
117 _ => Err(unsupported_data_type_exec_err(
118 self.name(),
119 "String, Binary",
120 &arg_types[0],
121 )),
122 }?;
123 let bit_length_type = if arg_types[1].is_numeric() {
124 Ok(DataType::Int32)
125 } else if arg_types[1].is_null() {
126 Ok(DataType::Null)
127 } else {
128 Err(unsupported_data_type_exec_err(
129 self.name(),
130 "Numeric Type",
131 &arg_types[1],
132 ))
133 }?;
134
135 Ok(vec![expr_type, bit_length_type])
136 }
137}
138
139pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
140 match args {
141 [
142 ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
143 ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
144 ] => compute_sha2(
145 bit_length_arg,
146 &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
147 ),
148 [
149 ColumnarValue::Array(expr_arg),
150 ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
151 ] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]),
152 [
153 ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
154 ColumnarValue::Array(bit_length_arg),
155 ] => {
156 let arr: StringArray = bit_length_arg
157 .as_primitive::<Int32Type>()
158 .iter()
159 .map(|bit_length| {
160 match sha2([
161 ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
162 ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
163 ])
164 .unwrap()
165 {
166 ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
167 ColumnarValue::Array(arr) => arr
168 .as_string::<i32>()
169 .iter()
170 .map(|str| str.unwrap().to_string())
171 .next(), _ => unreachable!(),
173 }
174 })
175 .collect();
176 Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
177 }
178 [
179 ColumnarValue::Array(expr_arg),
180 ColumnarValue::Array(bit_length_arg),
181 ] => {
182 let expr_iter = expr_arg.as_string::<i32>().iter();
183 let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
184 let arr: StringArray = expr_iter
185 .zip(bit_length_iter)
186 .map(|(expr, bit_length)| {
187 match sha2([
188 ColumnarValue::Scalar(ScalarValue::Utf8(Some(
189 expr.unwrap().to_string(),
190 ))),
191 ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
192 ])
193 .unwrap()
194 {
195 ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
196 ColumnarValue::Array(arr) => arr
197 .as_string::<i32>()
198 .iter()
199 .map(|str| str.unwrap().to_string())
200 .next(), _ => unreachable!(),
202 }
203 })
204 .collect();
205 Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
206 }
207 _ => exec_err!("Unsupported argument types for sha2 function"),
208 }
209}
210
211fn compute_sha2(
212 bit_length_arg: i32,
213 expr_arg: &[ColumnarValue],
214) -> Result<ColumnarValue> {
215 match bit_length_arg {
216 0 | 256 => sha256(expr_arg),
217 224 => sha224(expr_arg),
218 384 => sha384(expr_arg),
219 512 => sha512(expr_arg),
220 _ => {
221 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
224 }
225 }
226 .map(|hashed| spark_sha2_hex(&[hashed]).unwrap())
227}