datafusion_spark/function/hash/
sha2.rs1extern crate datafusion_functions;
19
20use crate::function::error_utils::{
21 invalid_arg_count_exec_err, unsupported_data_type_exec_err,
22};
23use crate::function::math::hex::spark_sha2_hex;
24use arrow::array::{ArrayRef, AsArray, StringArray};
25use arrow::datatypes::{DataType, Int32Type};
26use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
27use datafusion_expr::Signature;
28use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
29pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512};
30use std::any::Any;
31use std::sync::Arc;
32
/// Scalar UDF type registered under the name `sha2` (see its `ScalarUDFImpl::name`).
///
/// Takes an expression to hash and a bit length selecting the SHA-2 variant.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkSha2 {
    /// Function signature; `new` sets it to a user-defined, immutable signature.
    signature: Signature,
    /// Alternative names for the function; `new` leaves this empty.
    aliases: Vec<String>,
}
39
40impl Default for SparkSha2 {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46impl SparkSha2 {
47 pub fn new() -> Self {
48 Self {
49 signature: Signature::user_defined(Volatility::Immutable),
50 aliases: vec![],
51 }
52 }
53}
54
55impl ScalarUDFImpl for SparkSha2 {
56 fn as_any(&self) -> &dyn Any {
57 self
58 }
59
60 fn name(&self) -> &str {
61 "sha2"
62 }
63
64 fn signature(&self) -> &Signature {
65 &self.signature
66 }
67
68 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
69 if arg_types[1].is_null() {
70 return Ok(DataType::Null);
71 }
72 Ok(match arg_types[0] {
73 DataType::Utf8View
74 | DataType::LargeUtf8
75 | DataType::Utf8
76 | DataType::Binary
77 | DataType::BinaryView
78 | DataType::LargeBinary => DataType::Utf8,
79 DataType::Null => DataType::Null,
80 _ => {
81 return exec_err!(
82 "{} function can only accept strings or binary arrays.",
83 self.name()
84 )
85 }
86 })
87 }
88
89 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90 let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
91 internal_datafusion_err!("Expected 2 arguments for function sha2")
92 })?;
93
94 sha2(args)
95 }
96
97 fn aliases(&self) -> &[String] {
98 &self.aliases
99 }
100
101 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
102 if arg_types.len() != 2 {
103 return Err(invalid_arg_count_exec_err(
104 self.name(),
105 (2, 2),
106 arg_types.len(),
107 ));
108 }
109 let expr_type = match &arg_types[0] {
110 DataType::Utf8View
111 | DataType::LargeUtf8
112 | DataType::Utf8
113 | DataType::Binary
114 | DataType::BinaryView
115 | DataType::LargeBinary
116 | DataType::Null => Ok(arg_types[0].clone()),
117 _ => Err(unsupported_data_type_exec_err(
118 self.name(),
119 "String, Binary",
120 &arg_types[0],
121 )),
122 }?;
123 let bit_length_type = if arg_types[1].is_numeric() {
124 Ok(DataType::Int32)
125 } else if arg_types[1].is_null() {
126 Ok(DataType::Null)
127 } else {
128 Err(unsupported_data_type_exec_err(
129 self.name(),
130 "Numeric Type",
131 &arg_types[1],
132 ))
133 }?;
134
135 Ok(vec![expr_type, bit_length_type])
136 }
137}
138
139pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
140 match args {
141 [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
142 compute_sha2(
143 bit_length_arg,
144 &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
145 )
146 }
147 [ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
148 compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)])
149 }
150 [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] =>
151 {
152 let arr: StringArray = bit_length_arg
153 .as_primitive::<Int32Type>()
154 .iter()
155 .map(|bit_length| {
156 match sha2([
157 ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
158 ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
159 ])
160 .unwrap()
161 {
162 ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
163 ColumnarValue::Array(arr) => arr
164 .as_string::<i32>()
165 .iter()
166 .map(|str| str.unwrap().to_string())
167 .next(), _ => unreachable!(),
169 }
170 })
171 .collect();
172 Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
173 }
174 [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => {
175 let expr_iter = expr_arg.as_string::<i32>().iter();
176 let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
177 let arr: StringArray = expr_iter
178 .zip(bit_length_iter)
179 .map(|(expr, bit_length)| {
180 match sha2([
181 ColumnarValue::Scalar(ScalarValue::Utf8(Some(
182 expr.unwrap().to_string(),
183 ))),
184 ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
185 ])
186 .unwrap()
187 {
188 ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
189 ColumnarValue::Array(arr) => arr
190 .as_string::<i32>()
191 .iter()
192 .map(|str| str.unwrap().to_string())
193 .next(), _ => unreachable!(),
195 }
196 })
197 .collect();
198 Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
199 }
200 _ => exec_err!("Unsupported argument types for sha2 function"),
201 }
202}
203
204fn compute_sha2(
205 bit_length_arg: i32,
206 expr_arg: &[ColumnarValue],
207) -> Result<ColumnarValue> {
208 match bit_length_arg {
209 0 | 256 => sha256(expr_arg),
210 224 => sha224(expr_arg),
211 384 => sha384(expr_arg),
212 512 => sha512(expr_arg),
213 _ => {
214 return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
217 }
218 }
219 .map(|hashed| spark_sha2_hex(&[hashed]).unwrap())
220}