datafusion_comet_spark_expr/hash_funcs/
sha2.rs1use crate::math_funcs::hex::hex_strings;
19use arrow_array::{Array, StringArray};
20use arrow_schema::DataType;
21use datafusion::functions::crypto::{sha224, sha256, sha384, sha512};
22use datafusion_common::cast::as_binary_array;
23use datafusion_common::{exec_err, DataFusionError, ScalarValue};
24use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF};
25use std::sync::Arc;
26
27pub fn spark_sha224(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
29 wrap_digest_result_as_hex_string(args, sha224())
30}
31
32pub fn spark_sha256(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
34 wrap_digest_result_as_hex_string(args, sha256())
35}
36
37pub fn spark_sha384(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
39 wrap_digest_result_as_hex_string(args, sha384())
40}
41
42pub fn spark_sha512(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
44 wrap_digest_result_as_hex_string(args, sha512())
45}
46
47fn wrap_digest_result_as_hex_string(
50 args: &[ColumnarValue],
51 digest: Arc<ScalarUDF>,
52) -> Result<ColumnarValue, DataFusionError> {
53 let row_count = match &args[0] {
54 ColumnarValue::Array(array) => array.len(),
55 ColumnarValue::Scalar(_) => 1,
56 };
57 let value = digest.invoke_with_args(ScalarFunctionArgs {
58 args: args.into(),
59 number_rows: row_count,
60 return_type: &DataType::Utf8,
61 })?;
62 match value {
63 ColumnarValue::Array(array) => {
64 let binary_array = as_binary_array(&array)?;
65 let string_array: StringArray = binary_array
66 .iter()
67 .map(|opt| opt.map(hex_strings::<_>))
68 .collect();
69 Ok(ColumnarValue::Array(Arc::new(string_array)))
70 }
71 ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
72 ScalarValue::Utf8(opt.map(hex_strings::<_>)),
73 )),
74 _ => {
75 exec_err!(
76 "digest function should return binary value, but got: {:?}",
77 value.data_type()
78 )
79 }
80 }
81}